Skip to content

Commit 0f08bce

Browse files
committed
fix: make token counting API requests asynchronous for Gemini/Anthropic
- Modified countTokens() in GeminiHandler to return tiktoken estimate immediately
- Modified countTokens() in AnthropicHandler to return tiktoken estimate immediately
- API calls now happen asynchronously in the background without blocking inference
- Added tests to verify asynchronous behavior and immediate returns
- Vertex provider automatically inherits the fix from GeminiHandler

Fixes #3666
1 parent df6c57d commit 0f08bce

File tree

4 files changed

+172
-31
lines changed

4 files changed

+172
-31
lines changed

src/api/providers/__tests__/anthropic.spec.ts

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -265,4 +265,62 @@ describe("AnthropicHandler", () => {
265265
expect(result.temperature).toBe(0)
266266
})
267267
})
268+
269+
describe("countTokens", () => {
270+
it("should return fallback count immediately without waiting for API", async () => {
271+
// Mock the countTokens API to take a long time
272+
const mockCountTokens = vitest.fn().mockImplementation(() => {
273+
return new Promise((resolve) => {
274+
setTimeout(() => resolve({ input_tokens: 100 }), 1000)
275+
})
276+
})
277+
278+
;(handler as any).client.messages.countTokens = mockCountTokens
279+
280+
// Mock the base class countTokens to return a known value
281+
const baseSpy = vitest.spyOn(handler.constructor.prototype.__proto__, "countTokens")
282+
baseSpy.mockResolvedValue(50)
283+
284+
const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]
285+
286+
const startTime = Date.now()
287+
const result = await handler.countTokens(content)
288+
const endTime = Date.now()
289+
290+
// Should return immediately (less than 100ms)
291+
expect(endTime - startTime).toBeLessThan(100)
292+
293+
// Should return the fallback count
294+
expect(result).toBe(50)
295+
296+
// Should have called the base class method
297+
expect(baseSpy).toHaveBeenCalledWith(content)
298+
299+
// Should have started the async API call
300+
expect(mockCountTokens).toHaveBeenCalled()
301+
302+
baseSpy.mockRestore()
303+
})
304+
305+
it("should handle async API errors gracefully", async () => {
306+
// Mock the countTokens API to throw an error
307+
const mockCountTokens = vitest.fn().mockRejectedValue(new Error("API Error"))
308+
;(handler as any).client.messages.countTokens = mockCountTokens
309+
310+
// Mock the base class countTokens
311+
const baseSpy = vitest.spyOn(handler.constructor.prototype.__proto__, "countTokens")
312+
baseSpy.mockResolvedValue(75)
313+
314+
const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]
315+
316+
// Should not throw even if async call fails
317+
const result = await handler.countTokens(content)
318+
expect(result).toBe(75)
319+
320+
// Wait a bit to ensure async error is handled
321+
await new Promise((resolve) => setTimeout(resolve, 100))
322+
323+
baseSpy.mockRestore()
324+
})
325+
})
268326
})

src/api/providers/__tests__/gemini.spec.ts

Lines changed: 58 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -247,4 +247,62 @@ describe("GeminiHandler", () => {
247247
expect(cost).toBeUndefined()
248248
})
249249
})
250+
251+
describe("countTokens", () => {
252+
it("should return fallback count immediately without waiting for API", async () => {
253+
// Mock the countTokens API to take a long time
254+
const mockCountTokens = vitest.fn().mockImplementation(() => {
255+
return new Promise((resolve) => {
256+
setTimeout(() => resolve({ totalTokens: 100 }), 1000)
257+
})
258+
})
259+
260+
handler["client"].models.countTokens = mockCountTokens
261+
262+
// Mock the base class countTokens to return a known value
263+
const baseSpy = vitest.spyOn(handler.constructor.prototype.__proto__, "countTokens")
264+
baseSpy.mockResolvedValue(50)
265+
266+
const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]
267+
268+
const startTime = Date.now()
269+
const result = await handler.countTokens(content)
270+
const endTime = Date.now()
271+
272+
// Should return immediately (less than 100ms)
273+
expect(endTime - startTime).toBeLessThan(100)
274+
275+
// Should return the fallback count
276+
expect(result).toBe(50)
277+
278+
// Should have called the base class method
279+
expect(baseSpy).toHaveBeenCalledWith(content)
280+
281+
// Should have started the async API call
282+
expect(mockCountTokens).toHaveBeenCalled()
283+
284+
baseSpy.mockRestore()
285+
})
286+
287+
it("should handle async API errors gracefully", async () => {
288+
// Mock the countTokens API to throw an error
289+
const mockCountTokens = vitest.fn().mockRejectedValue(new Error("API Error"))
290+
handler["client"].models.countTokens = mockCountTokens
291+
292+
// Mock the base class countTokens
293+
const baseSpy = vitest.spyOn(handler.constructor.prototype.__proto__, "countTokens")
294+
baseSpy.mockResolvedValue(75)
295+
296+
const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]
297+
298+
// Should not throw even if async call fails
299+
const result = await handler.countTokens(content)
300+
expect(result).toBe(75)
301+
302+
// Wait a bit to ensure async error is handled
303+
await new Promise((resolve) => setTimeout(resolve, 100))
304+
305+
baseSpy.mockRestore()
306+
})
307+
})
250308
})

src/api/providers/anthropic.ts

Lines changed: 28 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -278,22 +278,33 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
278278
* @returns A promise resolving to the token count
279279
*/
280280
override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
281-
try {
282-
// Use the current model
283-
const { id: model } = this.getModel()
284-
285-
const response = await this.client.messages.countTokens({
286-
model,
287-
messages: [{ role: "user", content: content }],
288-
})
289-
290-
return response.input_tokens
291-
} catch (error) {
292-
// Log error but fallback to tiktoken estimation
293-
console.warn("Anthropic token counting failed, using fallback", error)
294-
295-
// Use the base provider's implementation as fallback
296-
return super.countTokens(content)
297-
}
281+
// Immediately return the tiktoken estimate
282+
const fallbackCount = super.countTokens(content)
283+
284+
// Start the API call asynchronously (fire and forget)
285+
this.countTokensAsync(content).catch((error) => {
286+
// Log error but don't throw - we already returned the fallback
287+
console.debug("Anthropic async token counting failed:", error)
288+
})
289+
290+
return fallbackCount
291+
}
292+
293+
/**
294+
* Performs the actual API call to count tokens asynchronously
295+
* This method is called in the background and doesn't block the main request
296+
*/
297+
private async countTokensAsync(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
298+
// Use the current model
299+
const { id: model } = this.getModel()
300+
301+
const response = await this.client.messages.countTokens({
302+
model,
303+
messages: [{ role: "user", content: content }],
304+
})
305+
306+
// In the future, we could cache this result or use it for telemetry
307+
console.debug(`Anthropic token count: API=${response.input_tokens}`)
308+
return response.input_tokens
298309
}
299310
}

src/api/providers/gemini.ts

Lines changed: 28 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -168,24 +168,38 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
168168
}
169169

170170
override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
171-
try {
172-
const { id: model } = this.getModel()
171+
// Immediately return the tiktoken estimate
172+
const fallbackCount = super.countTokens(content)
173173

174-
const response = await this.client.models.countTokens({
175-
model,
176-
contents: convertAnthropicContentToGemini(content),
177-
})
174+
// Start the API call asynchronously (fire and forget)
175+
this.countTokensAsync(content).catch((error) => {
176+
// Log error but don't throw - we already returned the fallback
177+
console.debug("Gemini async token counting failed:", error)
178+
})
178179

179-
if (response.totalTokens === undefined) {
180-
console.warn("Gemini token counting returned undefined, using fallback")
181-
return super.countTokens(content)
182-
}
180+
return fallbackCount
181+
}
183182

184-
return response.totalTokens
185-
} catch (error) {
186-
console.warn("Gemini token counting failed, using fallback", error)
187-
return super.countTokens(content)
183+
/**
184+
* Performs the actual API call to count tokens asynchronously
185+
* This method is called in the background and doesn't block the main request
186+
*/
187+
private async countTokensAsync(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
188+
const { id: model } = this.getModel()
189+
190+
const response = await this.client.models.countTokens({
191+
model,
192+
contents: convertAnthropicContentToGemini(content),
193+
})
194+
195+
if (response.totalTokens === undefined) {
196+
console.debug("Gemini token counting returned undefined")
197+
return 0
188198
}
199+
200+
// In the future, we could cache this result or use it for telemetry
201+
console.debug(`Gemini token count: API=${response.totalTokens}`)
202+
return response.totalTokens
189203
}
190204

191205
public calculateCost({

0 commit comments

Comments
 (0)