diff --git a/src/api/providers/__tests__/anthropic.spec.ts b/src/api/providers/__tests__/anthropic.spec.ts index b1d0a2f6b35..2da5f7cb10b 100644 --- a/src/api/providers/__tests__/anthropic.spec.ts +++ b/src/api/providers/__tests__/anthropic.spec.ts @@ -265,4 +265,62 @@ describe("AnthropicHandler", () => { expect(result.temperature).toBe(0) }) }) + + describe("countTokens", () => { + it("should return fallback count immediately without waiting for API", async () => { + // Mock the countTokens API to take a long time + const mockCountTokens = vitest.fn().mockImplementation(() => { + return new Promise((resolve) => { + setTimeout(() => resolve({ input_tokens: 100 }), 1000) + }) + }) + + ;(handler as any).client.messages.countTokens = mockCountTokens + + // Mock the base class countTokens to return a known value + const baseSpy = vitest.spyOn(handler.constructor.prototype.__proto__, "countTokens") + baseSpy.mockResolvedValue(50) + + const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }] + + const startTime = Date.now() + const result = await handler.countTokens(content) + const endTime = Date.now() + + // Should return immediately (less than 100ms) + expect(endTime - startTime).toBeLessThan(100) + + // Should return the fallback count + expect(result).toBe(50) + + // Should have called the base class method + expect(baseSpy).toHaveBeenCalledWith(content) + + // Should have started the async API call + expect(mockCountTokens).toHaveBeenCalled() + + baseSpy.mockRestore() + }) + + it("should handle async API errors gracefully", async () => { + // Mock the countTokens API to throw an error + const mockCountTokens = vitest.fn().mockRejectedValue(new Error("API Error")) + ;(handler as any).client.messages.countTokens = mockCountTokens + + // Mock the base class countTokens + const baseSpy = vitest.spyOn(handler.constructor.prototype.__proto__, "countTokens") + baseSpy.mockResolvedValue(75) + + const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }] + + // Should not throw even if async call fails + const result = await handler.countTokens(content) + expect(result).toBe(75) + + // Wait a bit to ensure async error is handled + await new Promise((resolve) => setTimeout(resolve, 100)) + + baseSpy.mockRestore() + }) + }) }) diff --git a/src/api/providers/__tests__/gemini.spec.ts b/src/api/providers/__tests__/gemini.spec.ts index 8a7fd24fe36..bff484d7738 100644 --- a/src/api/providers/__tests__/gemini.spec.ts +++ b/src/api/providers/__tests__/gemini.spec.ts @@ -247,4 +247,62 @@ describe("GeminiHandler", () => { expect(cost).toBeUndefined() }) }) + + describe("countTokens", () => { + it("should return fallback count immediately without waiting for API", async () => { + // Mock the countTokens API to take a long time + const mockCountTokens = vitest.fn().mockImplementation(() => { + return new Promise((resolve) => { + setTimeout(() => resolve({ totalTokens: 100 }), 1000) + }) + }) + + handler["client"].models.countTokens = mockCountTokens + + // Mock the base class countTokens to return a known value + const baseSpy = vitest.spyOn(handler.constructor.prototype.__proto__, "countTokens") + baseSpy.mockResolvedValue(50) + + const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }] + + const startTime = Date.now() + const result = await handler.countTokens(content) + const endTime = Date.now() + + // Should return immediately (less than 100ms) + expect(endTime - startTime).toBeLessThan(100) + + // Should return the fallback count + expect(result).toBe(50) + + // Should have called the base class method + expect(baseSpy).toHaveBeenCalledWith(content) + + // Should have started the async API call + expect(mockCountTokens).toHaveBeenCalled() + + baseSpy.mockRestore() + }) + + it("should handle async API errors gracefully", async () => { + // Mock the countTokens API to throw an error + const mockCountTokens = vitest.fn().mockRejectedValue(new Error("API Error")) + handler["client"].models.countTokens = mockCountTokens + + // Mock the base class countTokens + const baseSpy = vitest.spyOn(handler.constructor.prototype.__proto__, "countTokens") + baseSpy.mockResolvedValue(75) + + const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }] + + // Should not throw even if async call fails + const result = await handler.countTokens(content) + expect(result).toBe(75) + + // Wait a bit to ensure async error is handled + await new Promise((resolve) => setTimeout(resolve, 100)) + + baseSpy.mockRestore() + }) + }) }) diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index 52dec1ae55d..67650ec4fd4 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -278,22 +278,33 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa * @returns A promise resolving to the token count */ override async countTokens(content: Array): Promise { - try { - // Use the current model - const { id: model } = this.getModel() - - const response = await this.client.messages.countTokens({ - model, - messages: [{ role: "user", content: content }], - }) - - return response.input_tokens - } catch (error) { - // Log error but fallback to tiktoken estimation - console.warn("Anthropic token counting failed, using fallback", error) - - // Use the base provider's implementation as fallback - return super.countTokens(content) - } + // Immediately return the tiktoken estimate + const fallbackCount = super.countTokens(content) + + // Start the API call asynchronously (fire and forget) + this.countTokensAsync(content).catch((error) => { + // Log error but don't throw - we already returned the fallback + console.debug("Anthropic async token counting failed:", error) + }) + + return fallbackCount + } + + /** + * Performs the actual API call to count tokens asynchronously + * This method is called in the background and doesn't block the main request + */ + private async countTokensAsync(content: Array): Promise { + // Use the current model + const { id: model } = this.getModel() + + const response = await this.client.messages.countTokens({ + model, + messages: [{ role: "user", content: content }], + }) + + // In the future, we could cache this result or use it for telemetry + console.debug(`Anthropic token count: API=${response.input_tokens}`) + return response.input_tokens } } diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 6765c8676d8..544a9fd24b7 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -168,24 +168,38 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl } override async countTokens(content: Array): Promise { - try { - const { id: model } = this.getModel() + // Immediately return the tiktoken estimate + const fallbackCount = super.countTokens(content) - const response = await this.client.models.countTokens({ - model, - contents: convertAnthropicContentToGemini(content), - }) + // Start the API call asynchronously (fire and forget) + this.countTokensAsync(content).catch((error) => { + // Log error but don't throw - we already returned the fallback + console.debug("Gemini async token counting failed:", error) + }) - if (response.totalTokens === undefined) { - console.warn("Gemini token counting returned undefined, using fallback") - return super.countTokens(content) - } + return fallbackCount + } - return response.totalTokens - } catch (error) { - console.warn("Gemini token counting failed, using fallback", error) - return super.countTokens(content) + /** + * Performs the actual API call to count tokens asynchronously + * This method is called in the background and doesn't block the main request + */ + private async countTokensAsync(content: Array): Promise { + const { id: model } = this.getModel() + + const response = await this.client.models.countTokens({ + model, + contents: convertAnthropicContentToGemini(content), + }) + + if (response.totalTokens === undefined) { + console.debug("Gemini token counting returned undefined") + return 0 } + + // In the future, we could cache this result or use it for telemetry + console.debug(`Gemini token count: API=${response.totalTokens}`) + return response.totalTokens } public calculateCost({