35 changes: 30 additions & 5 deletions package-lock.json

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
@@ -405,7 +405,7 @@
 		"@anthropic-ai/vertex-sdk": "^0.7.0",
 		"@aws-sdk/client-bedrock-runtime": "^3.779.0",
 		"@google-cloud/vertexai": "^1.9.3",
-		"@google/generative-ai": "^0.18.0",
+		"@google/genai": "^0.9.0",
 		"@mistralai/mistralai": "^1.3.6",
 		"@modelcontextprotocol/sdk": "^1.7.0",
 		"@types/clone-deep": "^4.0.4",
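For orientation, here is a minimal sketch (not part of this PR) of what the dependency swap changes at the call site, assuming the ~0.9 @google/genai surface: the old SDK scopes requests to a model object returned by getGenerativeModel, while the new SDK exposes request methods on client.models and takes the model id per call. The model id below is illustrative.

	// Old SDK (@google/generative-ai): requests go through a model-scoped object.
	import { GoogleGenerativeAI } from "@google/generative-ai"
	// New SDK (@google/genai): a flat client; the model id travels with each request.
	import { GoogleGenAI } from "@google/genai"

	async function compareSdks(apiKey: string) {
		// Before: getGenerativeModel() binds the model id up front.
		const oldClient = new GoogleGenerativeAI(apiKey)
		const model = oldClient.getGenerativeModel({ model: "gemini-2.0-flash-001" })
		const oldResult = await model.generateContent("Hello")
		console.log(oldResult.response.text()) // text is a method on response

		// After: client.models.generateContent() takes model and config inline.
		const newClient = new GoogleGenAI({ apiKey })
		const newResult = await newClient.models.generateContent({
			model: "gemini-2.0-flash-001",
			contents: [{ role: "user", parts: [{ text: "Hello" }] }],
			config: { temperature: 0 },
		})
		console.log(newResult.text) // text is a plain property here
	}

This shape change is what drives most of the test rewrites below: there is no longer a per-model object to mock, only methods under client.models.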
139 changes: 45 additions & 94 deletions src/api/providers/__tests__/gemini.test.ts
@@ -1,46 +1,39 @@
+import { GeminiHandler } from "../gemini"
 // npx jest src/api/providers/__tests__/gemini.test.ts

 import { Anthropic } from "@anthropic-ai/sdk"
-import { GoogleGenerativeAI } from "@google/generative-ai"
-
-// Mock the Google Generative AI SDK
-jest.mock("@google/generative-ai", () => ({
-	GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
-		getGenerativeModel: jest.fn().mockReturnValue({
-			generateContentStream: jest.fn(),
-			generateContent: jest.fn().mockResolvedValue({
-				response: {
-					text: () => "Test response",
-				},
-			}),
-		}),
-	})),
-}))
-
-import { GeminiHandler } from "../gemini"

 describe("GeminiHandler", () => {
 	let handler: GeminiHandler

 	beforeEach(() => {
+		// Create mock functions
+		const mockGenerateContentStream = jest.fn()
+		const mockGenerateContent = jest.fn()
+		const mockGetGenerativeModel = jest.fn()
+
 		handler = new GeminiHandler({
-			apiKey: "test-key",
 			apiModelId: "gemini-2.0-flash-thinking-exp-1219",
 			geminiApiKey: "test-key",
 		})
+
+		// Replace the client with our mock
+		handler.client = {
+			models: {
+				generateContentStream: mockGenerateContentStream,
+				generateContent: mockGenerateContent,
+				getGenerativeModel: mockGetGenerativeModel,
+			},
+		} as any
 	})

 	describe("constructor", () => {
 		it("should initialize with provided config", () => {
 			expect(handler["options"].geminiApiKey).toBe("test-key")
 			expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
 		})

 		it.skip("should throw if API key is missing", () => {
 			expect(() => {
 				new GeminiHandler({
 					apiModelId: "gemini-2.0-flash-thinking-exp-1219",
 					geminiApiKey: "",
 				})
 			}).toThrow("API key is required for Google Gemini")
 		})
 	})

 	describe("createMessage", () => {
@@ -58,25 +51,15 @@ describe("GeminiHandler", () => {
 		const systemPrompt = "You are a helpful assistant"

 		it("should handle text messages correctly", async () => {
-			// Mock the stream response
-			const mockStream = {
-				stream: [{ text: () => "Hello" }, { text: () => " world!" }],
-				response: {
-					usageMetadata: {
-						promptTokenCount: 10,
-						candidatesTokenCount: 5,
-					},
-				},
-			}
-
-			// Setup the mock implementation
-			const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContentStream: mockGenerateContentStream,
-			})
-
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
+			// Setup the mock implementation to return an async generator
+			;(handler.client.models.generateContentStream as jest.Mock).mockResolvedValue({
+				[Symbol.asyncIterator]: async function* () {
+					yield { text: "Hello" }
+					yield { text: " world!" }
+					yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
+				},
+			})

 			const stream = handler.createMessage(systemPrompt, mockMessages)
 			const chunks = []

@@ -100,99 +83,67 @@ describe("GeminiHandler", () => {
 				outputTokens: 5,
 			})

-			// Verify the model configuration
-			expect(mockGetGenerativeModel).toHaveBeenCalledWith(
-				{
-					model: "gemini-2.0-flash-thinking-exp-1219",
-					systemInstruction: systemPrompt,
-				},
-				{
-					baseUrl: undefined,
-				},
-			)
-
-			// Verify generation config
-			expect(mockGenerateContentStream).toHaveBeenCalledWith(
+			// Verify the call to generateContentStream
+			expect(handler.client.models.generateContentStream).toHaveBeenCalledWith(
 				expect.objectContaining({
-					generationConfig: {
+					model: "gemini-2.0-flash-thinking-exp-1219",
+					config: expect.objectContaining({
 						temperature: 0,
-					},
+						systemInstruction: systemPrompt,
+					}),
 				}),
 			)
 		})

 		it("should handle API errors", async () => {
 			const mockError = new Error("Gemini API error")
-			const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContentStream: mockGenerateContentStream,
-			})
-
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
+			;(handler.client.models.generateContentStream as jest.Mock).mockRejectedValue(mockError)

 			const stream = handler.createMessage(systemPrompt, mockMessages)

 			await expect(async () => {
 				for await (const chunk of stream) {
 					// Should throw before yielding any chunks
 				}
-			}).rejects.toThrow("Gemini API error")
+			}).rejects.toThrow()
 		})
 	})

 	describe("completePrompt", () => {
 		it("should complete prompt successfully", async () => {
-			const mockGenerateContent = jest.fn().mockResolvedValue({
-				response: {
-					text: () => "Test response",
-				},
-			})
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContent: mockGenerateContent,
-			})
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
+			// Mock the response with text property
+			;(handler.client.models.generateContent as jest.Mock).mockResolvedValue({
+				text: "Test response",
+			})

 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("Test response")
-			expect(mockGetGenerativeModel).toHaveBeenCalledWith(
-				{
-					model: "gemini-2.0-flash-thinking-exp-1219",
-				},
-				{
-					baseUrl: undefined,
-				},
-			)
-			expect(mockGenerateContent).toHaveBeenCalledWith({
+
+			// Verify the call to generateContent
+			expect(handler.client.models.generateContent).toHaveBeenCalledWith({
+				model: "gemini-2.0-flash-thinking-exp-1219",
 				contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
-				generationConfig: {
+				config: {
+					httpOptions: undefined,
 					temperature: 0,
 				},
 			})
 		})

 		it("should handle API errors", async () => {
 			const mockError = new Error("Gemini API error")
-			const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContent: mockGenerateContent,
-			})
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
+			;(handler.client.models.generateContent as jest.Mock).mockRejectedValue(mockError)

 			await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
 				"Gemini completion error: Gemini API error",
 			)
 		})

 		it("should handle empty response", async () => {
-			const mockGenerateContent = jest.fn().mockResolvedValue({
-				response: {
-					text: () => "",
-				},
-			})
-			const mockGetGenerativeModel = jest.fn().mockReturnValue({
-				generateContent: mockGenerateContent,
-			})
-			;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
+			// Mock the response with empty text
+			;(handler.client.models.generateContent as jest.Mock).mockResolvedValue({
+				text: "",
+			})

 			const result = await handler.completePrompt("Test prompt")
 			expect(result).toBe("")
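A note on the updated mocking strategy: createMessage consumes the SDK stream with for await, so the test only needs generateContentStream to resolve to something async-iterable, with no jest.mock of the whole module. A standalone sketch of the pattern (names illustrative, not from the PR):

	// Any object exposing Symbol.asyncIterator satisfies a `for await` loop,
	// so a hand-rolled async generator can stand in for the SDK's real stream.
	async function consume(stream: AsyncIterable<{ text?: string }>): Promise<string> {
		const parts: string[] = []
		for await (const chunk of stream) {
			if (chunk.text) parts.push(chunk.text)
		}
		return parts.join("")
	}

	const fakeStream = {
		[Symbol.asyncIterator]: async function* () {
			yield { text: "Hello" }
			yield { text: " world!" }
		},
	}

	// consume(fakeStream) resolves to "Hello world!"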
16 changes: 6 additions & 10 deletions src/api/providers/anthropic.ts
@@ -23,6 +23,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler {

 		const apiKeyFieldName =
 			this.options.anthropicBaseUrl && this.options.anthropicUseAuthToken ? "authToken" : "apiKey"
+
 		this.client = new Anthropic({
 			baseURL: this.options.anthropicBaseUrl || undefined,
 			[apiKeyFieldName]: this.options.apiKey,
@@ -217,10 +218,10 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler {
 	}

 	async completePrompt(prompt: string) {
-		let { id: modelId, temperature } = this.getModel()
+		let { id: model, temperature } = this.getModel()

 		const message = await this.client.messages.create({
-			model: modelId,
+			model,
 			max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS,
 			thinking: undefined,
 			temperature,
@@ -241,16 +242,11 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler {
 	override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
 		try {
 			// Use the current model
-			const actualModelId = this.getModel().id
+			const { id: model } = this.getModel()

 			const response = await this.client.messages.countTokens({
-				model: actualModelId,
-				messages: [
-					{
-						role: "user",
-						content: content,
-					},
-				],
+				model,
+				messages: [{ role: "user", content: content }],
 			})

 			return response.input_tokens
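The countTokens cleanup is behavior-preserving: it still issues a single messages.countTokens request and returns the reported input token count. A minimal standalone sketch of that call, assuming the @anthropic-ai/sdk surface already imported in this file (the model id is illustrative):

	import { Anthropic } from "@anthropic-ai/sdk"

	async function countTokensExample(): Promise<number> {
		const client = new Anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })

		// One round trip; the API returns the prompt's token count.
		const response = await client.messages.countTokens({
			model: "claude-3-5-sonnet-latest",
			messages: [{ role: "user", content: "Hello, world" }],
		})

		return response.input_tokens
	}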