Skip to content

Commit 5abea50

Browse files
authored
Support Gemini 2.5 Flash thinking (#2752)
1 parent 6772306 commit 5abea50

File tree

12 files changed

+305
-243
lines changed

12 files changed

+305
-243
lines changed

.changeset/shiny-poems-search.md

Lines changed: 5 additions & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -0,0 +1,5 @@
1+
---
2+
"roo-cline": patch
3+
---
4+
5+
Support Gemini 2.5 Flash thinking mode

package-lock.json

Lines changed: 30 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -405,7 +405,7 @@
405405
"@anthropic-ai/vertex-sdk": "^0.7.0",
406406
"@aws-sdk/client-bedrock-runtime": "^3.779.0",
407407
"@google-cloud/vertexai": "^1.9.3",
408-
"@google/generative-ai": "^0.18.0",
408+
"@google/genai": "^0.9.0",
409409
"@mistralai/mistralai": "^1.3.6",
410410
"@modelcontextprotocol/sdk": "^1.7.0",
411411
"@types/clone-deep": "^4.0.4",
Lines changed: 52 additions & 98 deletions
Original file line number · Diff line number · Diff line change
@@ -1,45 +1,41 @@
1-
import { GeminiHandler } from "../gemini"
1+
// npx jest src/api/providers/__tests__/gemini.test.ts
2+
23
import { Anthropic } from "@anthropic-ai/sdk"
3-
import { GoogleGenerativeAI } from "@google/generative-ai"
4-
5-
// Mock the Google Generative AI SDK
6-
jest.mock("@google/generative-ai", () => ({
7-
GoogleGenerativeAI: jest.fn().mockImplementation(() => ({
8-
getGenerativeModel: jest.fn().mockReturnValue({
9-
generateContentStream: jest.fn(),
10-
generateContent: jest.fn().mockResolvedValue({
11-
response: {
12-
text: () => "Test response",
13-
},
14-
}),
15-
}),
16-
})),
17-
}))
4+
5+
import { GeminiHandler } from "../gemini"
6+
import { geminiDefaultModelId } from "../../../shared/api"
7+
8+
const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"
189

1910
describe("GeminiHandler", () => {
2011
let handler: GeminiHandler
2112

2213
beforeEach(() => {
14+
// Create mock functions
15+
const mockGenerateContentStream = jest.fn()
16+
const mockGenerateContent = jest.fn()
17+
const mockGetGenerativeModel = jest.fn()
18+
2319
handler = new GeminiHandler({
2420
apiKey: "test-key",
25-
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
21+
apiModelId: GEMINI_20_FLASH_THINKING_NAME,
2622
geminiApiKey: "test-key",
2723
})
24+
25+
// Replace the client with our mock
26+
handler["client"] = {
27+
models: {
28+
generateContentStream: mockGenerateContentStream,
29+
generateContent: mockGenerateContent,
30+
getGenerativeModel: mockGetGenerativeModel,
31+
},
32+
} as any
2833
})
2934

3035
describe("constructor", () => {
3136
it("should initialize with provided config", () => {
3237
expect(handler["options"].geminiApiKey).toBe("test-key")
33-
expect(handler["options"].apiModelId).toBe("gemini-2.0-flash-thinking-exp-1219")
34-
})
35-
36-
it.skip("should throw if API key is missing", () => {
37-
expect(() => {
38-
new GeminiHandler({
39-
apiModelId: "gemini-2.0-flash-thinking-exp-1219",
40-
geminiApiKey: "",
41-
})
42-
}).toThrow("API key is required for Google Gemini")
38+
expect(handler["options"].apiModelId).toBe(GEMINI_20_FLASH_THINKING_NAME)
4339
})
4440
})
4541

@@ -58,25 +54,15 @@ describe("GeminiHandler", () => {
5854
const systemPrompt = "You are a helpful assistant"
5955

6056
it("should handle text messages correctly", async () => {
61-
// Mock the stream response
62-
const mockStream = {
63-
stream: [{ text: () => "Hello" }, { text: () => " world!" }],
64-
response: {
65-
usageMetadata: {
66-
promptTokenCount: 10,
67-
candidatesTokenCount: 5,
68-
},
57+
// Setup the mock implementation to return an async generator
58+
;(handler["client"].models.generateContentStream as jest.Mock).mockResolvedValue({
59+
[Symbol.asyncIterator]: async function* () {
60+
yield { text: "Hello" }
61+
yield { text: " world!" }
62+
yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } }
6963
},
70-
}
71-
72-
// Setup the mock implementation
73-
const mockGenerateContentStream = jest.fn().mockResolvedValue(mockStream)
74-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
75-
generateContentStream: mockGenerateContentStream,
7664
})
7765

78-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
79-
8066
const stream = handler.createMessage(systemPrompt, mockMessages)
8167
const chunks = []
8268

@@ -100,99 +86,67 @@ describe("GeminiHandler", () => {
10086
outputTokens: 5,
10187
})
10288

103-
// Verify the model configuration
104-
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
105-
{
106-
model: "gemini-2.0-flash-thinking-exp-1219",
107-
systemInstruction: systemPrompt,
108-
},
109-
{
110-
baseUrl: undefined,
111-
},
112-
)
113-
114-
// Verify generation config
115-
expect(mockGenerateContentStream).toHaveBeenCalledWith(
89+
// Verify the call to generateContentStream
90+
expect(handler["client"].models.generateContentStream).toHaveBeenCalledWith(
11691
expect.objectContaining({
117-
generationConfig: {
92+
model: GEMINI_20_FLASH_THINKING_NAME,
93+
config: expect.objectContaining({
11894
temperature: 0,
119-
},
95+
systemInstruction: systemPrompt,
96+
}),
12097
}),
12198
)
12299
})
123100

124101
it("should handle API errors", async () => {
125102
const mockError = new Error("Gemini API error")
126-
const mockGenerateContentStream = jest.fn().mockRejectedValue(mockError)
127-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
128-
generateContentStream: mockGenerateContentStream,
129-
})
130-
131-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
103+
;(handler["client"].models.generateContentStream as jest.Mock).mockRejectedValue(mockError)
132104

133105
const stream = handler.createMessage(systemPrompt, mockMessages)
134106

135107
await expect(async () => {
136108
for await (const chunk of stream) {
137109
// Should throw before yielding any chunks
138110
}
139-
}).rejects.toThrow("Gemini API error")
111+
}).rejects.toThrow()
140112
})
141113
})
142114

143115
describe("completePrompt", () => {
144116
it("should complete prompt successfully", async () => {
145-
const mockGenerateContent = jest.fn().mockResolvedValue({
146-
response: {
147-
text: () => "Test response",
148-
},
149-
})
150-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
151-
generateContent: mockGenerateContent,
117+
// Mock the response with text property
118+
;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({
119+
text: "Test response",
152120
})
153-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
154121

155122
const result = await handler.completePrompt("Test prompt")
156123
expect(result).toBe("Test response")
157-
expect(mockGetGenerativeModel).toHaveBeenCalledWith(
158-
{
159-
model: "gemini-2.0-flash-thinking-exp-1219",
160-
},
161-
{
162-
baseUrl: undefined,
163-
},
164-
)
165-
expect(mockGenerateContent).toHaveBeenCalledWith({
124+
125+
// Verify the call to generateContent
126+
expect(handler["client"].models.generateContent).toHaveBeenCalledWith({
127+
model: GEMINI_20_FLASH_THINKING_NAME,
166128
contents: [{ role: "user", parts: [{ text: "Test prompt" }] }],
167-
generationConfig: {
129+
config: {
130+
httpOptions: undefined,
168131
temperature: 0,
169132
},
170133
})
171134
})
172135

173136
it("should handle API errors", async () => {
174137
const mockError = new Error("Gemini API error")
175-
const mockGenerateContent = jest.fn().mockRejectedValue(mockError)
176-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
177-
generateContent: mockGenerateContent,
178-
})
179-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
138+
;(handler["client"].models.generateContent as jest.Mock).mockRejectedValue(mockError)
180139

181140
await expect(handler.completePrompt("Test prompt")).rejects.toThrow(
182141
"Gemini completion error: Gemini API error",
183142
)
184143
})
185144

186145
it("should handle empty response", async () => {
187-
const mockGenerateContent = jest.fn().mockResolvedValue({
188-
response: {
189-
text: () => "",
190-
},
191-
})
192-
const mockGetGenerativeModel = jest.fn().mockReturnValue({
193-
generateContent: mockGenerateContent,
146+
// Mock the response with empty text
147+
;(handler["client"].models.generateContent as jest.Mock).mockResolvedValue({
148+
text: "",
194149
})
195-
;(handler["client"] as any).getGenerativeModel = mockGetGenerativeModel
196150

197151
const result = await handler.completePrompt("Test prompt")
198152
expect(result).toBe("")
@@ -202,7 +156,7 @@ describe("GeminiHandler", () => {
202156
describe("getModel", () => {
203157
it("should return correct model info", () => {
204158
const modelInfo = handler.getModel()
205-
expect(modelInfo.id).toBe("gemini-2.0-flash-thinking-exp-1219")
159+
expect(modelInfo.id).toBe(GEMINI_20_FLASH_THINKING_NAME)
206160
expect(modelInfo.info).toBeDefined()
207161
expect(modelInfo.info.maxTokens).toBe(8192)
208162
expect(modelInfo.info.contextWindow).toBe(32_767)
@@ -214,7 +168,7 @@ describe("GeminiHandler", () => {
214168
geminiApiKey: "test-key",
215169
})
216170
const modelInfo = invalidHandler.getModel()
217-
expect(modelInfo.id).toBe("gemini-2.0-flash-001") // Default model
171+
expect(modelInfo.id).toBe(geminiDefaultModelId) // Default model
218172
})
219173
})
220174
})

src/api/providers/anthropic.ts

Lines changed: 6 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -23,6 +23,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
2323

2424
const apiKeyFieldName =
2525
this.options.anthropicBaseUrl && this.options.anthropicUseAuthToken ? "authToken" : "apiKey"
26+
2627
this.client = new Anthropic({
2728
baseURL: this.options.anthropicBaseUrl || undefined,
2829
[apiKeyFieldName]: this.options.apiKey,
@@ -217,10 +218,10 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
217218
}
218219

219220
async completePrompt(prompt: string) {
220-
let { id: modelId, temperature } = this.getModel()
221+
let { id: model, temperature } = this.getModel()
221222

222223
const message = await this.client.messages.create({
223-
model: modelId,
224+
model,
224225
max_tokens: ANTHROPIC_DEFAULT_MAX_TOKENS,
225226
thinking: undefined,
226227
temperature,
@@ -241,16 +242,11 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
241242
override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
242243
try {
243244
// Use the current model
244-
const actualModelId = this.getModel().id
245+
const { id: model } = this.getModel()
245246

246247
const response = await this.client.messages.countTokens({
247-
model: actualModelId,
248-
messages: [
249-
{
250-
role: "user",
251-
content: content,
252-
},
253-
],
248+
model,
249+
messages: [{ role: "user", content: content }],
254250
})
255251

256252
return response.input_tokens

0 commit comments

Comments (0)