Skip to content

Commit 57f2a4e

Browse files
committed
Support Gemini 2.5 Flash thinking
1 parent c329b45 commit 57f2a4e

File tree

7 files changed

+186
-221
lines changed

7 files changed

+186
-221
lines changed

package-lock.json

Lines changed: 30 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@
405405
"@anthropic-ai/vertex-sdk": "^0.7.0",
406406
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
407407
"@google-cloud/vertexai": "^1.9.3",
408-
"@google/generative-ai": "^0.18.0",
408+
"@google/genai": "^0.9.0",
409409
"@mistralai/mistralai": "^1.3.6",
410410
"@modelcontextprotocol/sdk": "^1.7.0",
411411
"@types/clone-deep": "^4.0.4",

src/api/providers/__tests__/gemini.test.ts

Lines changed: 45 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,39 @@
1-
import { GeminiHandler } from "../gemini"
1+
// npx jest src/api/providers/__tests__/gemini.test.ts
2+
23
import { Anthropic } from "@anthropic-ai/sdk"
3-
import { GoogleGenerativeAI } from "@google/generative-ai"
4-
5-
// Mock the Google Generative AI SDK
6-
jest.mock("@google/generative-ai", () => ({
7-
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
8-
getGenerativeModel: jest.fn().mockReturnValue({
9-
generateContentStream: jest.fn(),
10-
generateContent: jest.fn().mockResolvedValue({
11-
response: {
12-
text: () => "Test response",
13-
},
14-
}),
15-
}),
16-
})),
17-
}))
4+
5+
import { GeminiHandler } from "../gemini"
186

197
describe("GeminiHandler", () => {
208
let handler: GeminiHandler
219

2210
beforeEach(() => {
11+
// Create mock functions
12+
const mockGenerateContentStream = jest.fn()
13+
const mockGenerateContent = jest.fn()
14+
const mockGetGenerativeModel = jest.fn()
15+
2316
handler = new GeminiHandler({
2417
apiKey: "test-key",
2518
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
2619
geminiApiKey: "test-key",
2720
})
21+
22+
// Replace the client with our mock
23+
handler.client = {
24+
models: {
25+
generateContentStream: mockGenerateContentStream,
26+
generateContent: mockGenerateContent,
27+
getGenerativeModel: mockGetGenerativeModel,
28+
},
29+
} as any
2830
})
2931

3032
describe("constructor", () => {
3133
it("should initialize with provided config", () => {
3234
expect(handler["options"].geminiApiKey).toBe("test-key")
3335
expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
3436
})
35-
36-
it.skip("should throw if API key is missing", () => {
37-
expect(() => {
38-
new GeminiHandler({
39-
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
40-
geminiApiKey: "",
41-
})
42-
}).toThrow("API key is required for Google Gemini")
43-
})
4437
})
4538

4639
describe("createMessage", () => {
@@ -58,25 +51,15 @@ describe("GeminiHandler", () => {
5851
const systemPrompt = "You are a helpful assistant"
5952

6053
it("should handle text messages correctly", async () => {
61-
// Mock the stream response
62-
const mockStream = {
63-
stream: [{ text: () => "Hello" }, { text: () => " world!" }],
64-
response: {
65-
usageMetadata: {
66-
promptTokenCount: 10,
67-
candidatesTokenCount: 5,
68-
},
54+
// Setup the mock implementation to return an async generator
55+
;(handler.client.models.generateContentStream as jest.Mock).mockResolvedValue({
56+
[Symbol.asyncIterator]: async function* () {
57+
yield { text: "Hello" }
58+
yield { text: " world!" }
59+
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
6960
},
70-
}
71-
72-
// Setup the mock implementation
73-
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
74-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
75-
generateContentStream: mockGenerateContentStream,
7661
})
7762

78-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
79-
8063
const stream = handler.createMessage(systemPrompt, mockMessages)
8164
const chunks = []
8265

@@ -100,99 +83,67 @@ describe("GeminiHandler", () => {
10083
outputTokens: 5,
10184
})
10285

103-
// Verify the model configuration
104-
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
105-
{
106-
model: "gemini-2.0-flash-thinking-exp-1219",
107-
systemInstruction: systemPrompt,
108-
},
109-
{
110-
baseUrl: undefined,
111-
},
112-
)
113-
114-
// Verify generation config
115-
expect(mockGenerateContentStream).toHaveBeenCalledWith(
86+
// Verify the call to generateContentStream
87+
expect(handler.client.models.generateContentStream).toHaveBeenCalledWith(
11688
expect.objectContaining({
117-
generationConfig: {
89+
model: "gemini-2.0-flash-thinking-exp-1219",
90+
config: expect.objectContaining({
11891
temperature: 0,
119-
},
92+
systemInstruction: systemPrompt,
93+
}),
12094
}),
12195
)
12296
})
12397

12498
it("should handle API errors", async () => {
12599
const mockError = new Error("Gemini API error")
126-
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
127-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
128-
generateContentStream: mockGenerateContentStream,
129-
})
130-
131-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
100+
;(handler.client.models.generateContentStream as jest.Mock).mockRejectedValue(mockError)
132101

133102
const stream = handler.createMessage(systemPrompt, mockMessages)
134103

135104
await expect(async () => {
136105
for await (const chunk of stream) {
137106
// Should throw before yielding any chunks
138107
}
139-
}).rejects.toThrow("Gemini API error")
108+
}).rejects.toThrow()
140109
})
141110
})
142111

143112
describe("completePrompt", () => {
144113
it("should complete prompt successfully", async () => {
145-
const mockGenerateContent = jest.fn().mockResolvedValue({
146-
response: {
147-
text: () => "Test response",
148-
},
114+
// Mock the response with text property
115+
;(handler.client.models.generateContent as jest.Mock).mockResolvedValue({
116+
text: "Test response",
149117
})
150-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
151-
generateContent: mockGenerateContent,
152-
})
153-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
154118

155119
const result = await handler.completePrompt("Test prompt")
156120
expect(result).toBe("Test response")
157-
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
158-
{
159-
model: "gemini-2.0-flash-thinking-exp-1219",
160-
},
161-
{
162-
baseUrl: undefined,
163-
},
164-
)
165-
expect(mockGenerateContent).toHaveBeenCalledWith({
121+
122+
// Verify the call to generateContent
123+
expect(handler.client.models.generateContent).toHaveBeenCalledWith({
124+
model: "gemini-2.0-flash-thinking-exp-1219",
166125
contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
167-
generationConfig: {
126+
config: {
127+
httpOptions: undefined,
168128
temperature: 0,
169129
},
170130
})
171131
})
172132

173133
it("should handle API errors", async () => {
174134
const mockError = new Error("Gemini API error")
175-
const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
176-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
177-
generateContent: mockGenerateContent,
178-
})
179-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
135+
;(handler.client.models.generateContent as jest.Mock).mockRejectedValue(mockError)
180136

181137
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
182138
"Gemini completion error: Gemini API error",
183139
)
184140
})
185141

186142
it("should handle empty response", async () => {
187-
const mockGenerateContent = jest.fn().mockResolvedValue({
188-
response: {
189-
text: () => "",
190-
},
191-
})
192-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
193-
generateContent: mockGenerateContent,
143+
// Mock the response with empty text
144+
;(handler.client.models.generateContent as jest.Mock).mockResolvedValue({
145+
text: "",
194146
})
195-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
196147

197148
const result = await handler.completePrompt("Test prompt")
198149
expect(result).toBe("")

src/api/providers/anthropic.ts

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
2323

2424
const apiKeyFieldName =
2525
this.options.anthropicBaseUrl && this.options.anthropicUseAuthToken ? "authToken" : "apiKey"
26+
2627
this.client = new Anthropic({
2728
baseURL: this.options.anthropicBaseUrl || undefined,
2829
[apiKeyFieldName]: this.options.apiKey,
@@ -217,10 +218,10 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
217218
}
218219

219220
async completePrompt(prompt: string) {
220-
let { id: modelId, temperature } = this.getModel()
221+
let { id: model, temperature } = this.getModel()
221222

222223
const message = await this.client.messages.create({
223-
model: modelId,
224+
model,
224225
max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS,
225226
thinking: undefined,
226227
temperature,
@@ -241,16 +242,11 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
241242
override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
242243
try {
243244
// Use the current model
244-
const actualModelId = this.getModel().id
245+
const { id: model } = this.getModel()
245246

246247
const response = await this.client.messages.countTokens({
247-
model: actualModelId,
248-
messages: [
249-
{
250-
role: "user",
251-
content: content,
252-
},
253-
],
248+
model,
249+
messages: [{ role: "user", content: content }],
254250
})
255251

256252
return response.input_tokens

0 commit comments

Comments (0)