Skip to content

Commit 18995aa

Browse files
committed
fix: correct Gemini countTokens payload format to use Content[] instead of Part[]
- Fixed countTokens method to wrap parts in proper Content structure with user role - Added comprehensive tests for countTokens including multimodal content and fallback scenarios - This resolves premature context truncation due to incorrect token counting Fixes #8113
1 parent 87b45de commit 18995aa

File tree

2 files changed

+138
-1
lines changed

2 files changed

+138
-1
lines changed

src/api/providers/__tests__/gemini.spec.ts

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import { type ModelInfo, geminiDefaultModelId } from "@roo-code/types"
66

77
import { t } from "i18next"
88
import { GeminiHandler } from "../gemini"
9+
import { BaseProvider } from "../base-provider"
910

1011
const GEMINI_20_FLASH_THINKING_NAME = "gemini-2.0-flash-thinking-exp-1219"
1112

@@ -248,4 +249,138 @@ describe("GeminiHandler", () => {
248249
expect(cost).toBeUndefined()
249250
})
250251
})
252+
253+
describe("countTokens", () => {
254+
it("should count tokens successfully with correct Content[] format", async () => {
255+
// Mock the countTokens response
256+
const mockCountTokens = vitest.fn().mockResolvedValue({
257+
totalTokens: 42,
258+
})
259+
260+
handler["client"].models.countTokens = mockCountTokens
261+
262+
const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Hello world" }]
263+
264+
const result = await handler.countTokens(content)
265+
expect(result).toBe(42)
266+
267+
// Verify the call was made with correct Content[] format
268+
expect(mockCountTokens).toHaveBeenCalledWith({
269+
model: GEMINI_20_FLASH_THINKING_NAME,
270+
contents: [
271+
{
272+
role: "user",
273+
parts: [{ text: "Hello world" }],
274+
},
275+
],
276+
})
277+
})
278+
279+
it("should handle multimodal content correctly", async () => {
280+
const mockCountTokens = vitest.fn().mockResolvedValue({
281+
totalTokens: 100,
282+
})
283+
284+
handler["client"].models.countTokens = mockCountTokens
285+
286+
const content: Anthropic.Messages.ContentBlockParam[] = [
287+
{ type: "text", text: "Describe this image:" },
288+
{
289+
type: "image",
290+
source: {
291+
type: "base64",
292+
media_type: "image/jpeg",
293+
data: "base64data",
294+
},
295+
},
296+
]
297+
298+
const result = await handler.countTokens(content)
299+
expect(result).toBe(100)
300+
301+
// Verify the Content[] structure with mixed content
302+
expect(mockCountTokens).toHaveBeenCalledWith({
303+
model: GEMINI_20_FLASH_THINKING_NAME,
304+
contents: [
305+
{
306+
role: "user",
307+
parts: [
308+
{ text: "Describe this image:" },
309+
{ inlineData: { data: "base64data", mimeType: "image/jpeg" } },
310+
],
311+
},
312+
],
313+
})
314+
})
315+
316+
it("should fall back to base provider when SDK returns undefined", async () => {
317+
// Mock countTokens to return undefined totalTokens
318+
const mockCountTokens = vitest.fn().mockResolvedValue({
319+
totalTokens: undefined,
320+
})
321+
322+
handler["client"].models.countTokens = mockCountTokens
323+
324+
// Spy on the parent class method
325+
const superCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
326+
327+
const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]
328+
329+
await handler.countTokens(content)
330+
331+
// Verify fallback was called
332+
expect(superCountTokensSpy).toHaveBeenCalledWith(content)
333+
})
334+
335+
it("should fall back to base provider when SDK throws error", async () => {
336+
// Mock countTokens to throw an error
337+
const mockCountTokens = vitest.fn().mockRejectedValue(new Error("API error"))
338+
339+
handler["client"].models.countTokens = mockCountTokens
340+
341+
// Spy on console.warn
342+
const consoleWarnSpy = vitest.spyOn(console, "warn").mockImplementation(() => {})
343+
344+
// Spy on the parent class method
345+
const superCountTokensSpy = vitest.spyOn(BaseProvider.prototype, "countTokens")
346+
347+
const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]
348+
349+
await handler.countTokens(content)
350+
351+
// Verify warning was logged
352+
expect(consoleWarnSpy).toHaveBeenCalledWith(
353+
"Gemini token counting failed, using fallback",
354+
expect.any(Error),
355+
)
356+
357+
// Verify fallback was called
358+
expect(superCountTokensSpy).toHaveBeenCalledWith(content)
359+
360+
// Clean up
361+
consoleWarnSpy.mockRestore()
362+
})
363+
364+
it("should handle empty content array", async () => {
365+
const mockCountTokens = vitest.fn().mockResolvedValue({
366+
totalTokens: 0,
367+
})
368+
369+
handler["client"].models.countTokens = mockCountTokens
370+
371+
const result = await handler.countTokens([])
372+
expect(result).toBe(0)
373+
374+
// Verify the call with empty parts
375+
expect(mockCountTokens).toHaveBeenCalledWith({
376+
model: GEMINI_20_FLASH_THINKING_NAME,
377+
contents: [
378+
{
379+
role: "user",
380+
parts: [],
381+
},
382+
],
383+
})
384+
})
385+
})
251386
})

src/api/providers/gemini.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,9 +258,11 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
258258
try {
259259
const { id: model } = this.getModel()
260260

261+
// Wrap the parts in a proper Content structure with user role
262+
// The SDK expects Content[] format: [{ role: "user", parts: Part[] }]
261263
const response = await this.client.models.countTokens({
262264
model,
263-
contents: convertAnthropicContentToGemini(content),
265+
contents: [{ role: "user", parts: convertAnthropicContentToGemini(content) }],
264266
})
265267

266268
if (response.totalTokens === undefined) {

0 commit comments

Comments
 (0)