Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/api/providers/__tests__/anthropic.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,62 @@ describe("AnthropicHandler", () => {
expect(result.temperature).toBe(0)
})
})

describe("countTokens", () => {
	it("should return fallback count immediately without waiting for API", async () => {
		// Keep the API call pending until the test settles it explicitly. This avoids
		// wall-clock assertions (flaky on loaded CI machines) and leaves no timer
		// running after the test has finished.
		let settleApiCall: (value: { input_tokens: number }) => void = () => {}
		const mockCountTokens = vitest.fn().mockImplementation(
			() =>
				new Promise<{ input_tokens: number }>((resolve) => {
					settleApiCall = resolve
				}),
		)

		;(handler as any).client.messages.countTokens = mockCountTokens

		// Mock the base class countTokens to return a known value.
		// Object.getPrototypeOf is the standard replacement for the deprecated __proto__ accessor.
		const baseSpy = vitest.spyOn(Object.getPrototypeOf(Object.getPrototypeOf(handler)), "countTokens")
		baseSpy.mockResolvedValue(50)

		try {
			const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]

			// If countTokens awaited the API call, this await would never settle
			// (the API promise is still pending) and the test would fail by timeout.
			const result = await handler.countTokens(content)

			// Should return the fallback count
			expect(result).toBe(50)

			// Should have called the base class method
			expect(baseSpy).toHaveBeenCalledWith(content)

			// Should have started the async API call
			expect(mockCountTokens).toHaveBeenCalled()
		} finally {
			// Always settle the background call and restore the spy, even when an
			// assertion above fails, so nothing leaks into subsequent tests.
			settleApiCall({ input_tokens: 100 })
			baseSpy.mockRestore()
		}
	})

	it("should handle async API errors gracefully", async () => {
		// Mock the countTokens API to throw an error
		const mockCountTokens = vitest.fn().mockRejectedValue(new Error("API Error"))
		;(handler as any).client.messages.countTokens = mockCountTokens

		// Mock the base class countTokens
		const baseSpy = vitest.spyOn(Object.getPrototypeOf(Object.getPrototypeOf(handler)), "countTokens")
		baseSpy.mockResolvedValue(75)

		try {
			const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]

			// Should not throw even if async call fails
			const result = await handler.countTokens(content)
			expect(result).toBe(75)

			// The failing API call must actually have been attempted
			expect(mockCountTokens).toHaveBeenCalled()

			// Yield one macrotask so the background rejection and its handler
			// (both microtasks) have run before the test ends.
			await new Promise((resolve) => setTimeout(resolve, 0))
		} finally {
			baseSpy.mockRestore()
		}
	})
})
})
58 changes: 58 additions & 0 deletions src/api/providers/__tests__/gemini.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -247,4 +247,62 @@ describe("GeminiHandler", () => {
expect(cost).toBeUndefined()
})
})

describe("countTokens", () => {
	it("should return fallback count immediately without waiting for API", async () => {
		// Keep the API call pending until the test settles it explicitly. This avoids
		// wall-clock assertions (flaky on loaded CI machines) and leaves no timer
		// running after the test has finished.
		let settleApiCall: (value: { totalTokens: number }) => void = () => {}
		const mockCountTokens = vitest.fn().mockImplementation(
			() =>
				new Promise<{ totalTokens: number }>((resolve) => {
					settleApiCall = resolve
				}),
		)

		handler["client"].models.countTokens = mockCountTokens

		// Mock the base class countTokens to return a known value.
		// Object.getPrototypeOf is the standard replacement for the deprecated __proto__ accessor.
		const baseSpy = vitest.spyOn(Object.getPrototypeOf(Object.getPrototypeOf(handler)), "countTokens")
		baseSpy.mockResolvedValue(50)

		try {
			const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]

			// If countTokens awaited the API call, this await would never settle
			// (the API promise is still pending) and the test would fail by timeout.
			const result = await handler.countTokens(content)

			// Should return the fallback count
			expect(result).toBe(50)

			// Should have called the base class method
			expect(baseSpy).toHaveBeenCalledWith(content)

			// Should have started the async API call
			expect(mockCountTokens).toHaveBeenCalled()
		} finally {
			// Always settle the background call and restore the spy, even when an
			// assertion above fails, so nothing leaks into subsequent tests.
			settleApiCall({ totalTokens: 100 })
			baseSpy.mockRestore()
		}
	})

	it("should handle async API errors gracefully", async () => {
		// Mock the countTokens API to throw an error
		const mockCountTokens = vitest.fn().mockRejectedValue(new Error("API Error"))
		handler["client"].models.countTokens = mockCountTokens

		// Mock the base class countTokens
		const baseSpy = vitest.spyOn(Object.getPrototypeOf(Object.getPrototypeOf(handler)), "countTokens")
		baseSpy.mockResolvedValue(75)

		try {
			const content: Anthropic.Messages.ContentBlockParam[] = [{ type: "text", text: "Test content" }]

			// Should not throw even if async call fails
			const result = await handler.countTokens(content)
			expect(result).toBe(75)

			// The failing API call must actually have been attempted
			expect(mockCountTokens).toHaveBeenCalled()

			// Yield one macrotask so the background rejection and its handler
			// (both microtasks) have run before the test ends.
			await new Promise((resolve) => setTimeout(resolve, 0))
		} finally {
			baseSpy.mockRestore()
		}
	})
})
})
45 changes: 28 additions & 17 deletions src/api/providers/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -278,22 +278,33 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa
* @returns A promise resolving to the token count
*/
override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
try {
// Use the current model
const { id: model } = this.getModel()

const response = await this.client.messages.countTokens({
model,
messages: [{ role: "user", content: content }],
})

return response.input_tokens
} catch (error) {
// Log error but fallback to tiktoken estimation
console.warn("Anthropic token counting failed, using fallback", error)

// Use the base provider's implementation as fallback
return super.countTokens(content)
}
// Immediately return the tiktoken estimate
const fallbackCount = super.countTokens(content)

// Start the API call asynchronously (fire and forget)
this.countTokensAsync(content).catch((error) => {
// Log error but don't throw - we already returned the fallback
console.debug("Anthropic async token counting failed:", error)
})

return fallbackCount
}

/**
 * Requests a token count for `content` from the Anthropic messages API,
 * using the handler's current model and sending the content as a single
 * user message.
 *
 * The count is written to the debug log and returned. Any error raised by
 * the API call propagates to the caller as a rejected promise.
 *
 * @param content The content blocks to count tokens for
 * @returns A promise resolving to the `input_tokens` value reported by the API
 */
private async countTokensAsync(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
	// Resolve the active model id first; the API counts tokens per model
	const model = this.getModel().id

	const { input_tokens: apiTokenCount } = await this.client.messages.countTokens({
		model,
		messages: [{ role: "user", content }],
	})

	// Only logged for now; a later change could cache it or feed telemetry
	console.debug(`Anthropic token count: API=${apiTokenCount}`)
	return apiTokenCount
}
}
42 changes: 28 additions & 14 deletions src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -168,24 +168,38 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
}

override async countTokens(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
try {
const { id: model } = this.getModel()
// Immediately return the tiktoken estimate
const fallbackCount = super.countTokens(content)

const response = await this.client.models.countTokens({
model,
contents: convertAnthropicContentToGemini(content),
})
// Start the API call asynchronously (fire and forget)
this.countTokensAsync(content).catch((error) => {
// Log error but don't throw - we already returned the fallback
console.debug("Gemini async token counting failed:", error)
})

if (response.totalTokens === undefined) {
console.warn("Gemini token counting returned undefined, using fallback")
return super.countTokens(content)
}
return fallbackCount
}

return response.totalTokens
} catch (error) {
console.warn("Gemini token counting failed, using fallback", error)
return super.countTokens(content)
/**
 * Requests a token count for `content` from the Gemini models API, using
 * the handler's current model and the content converted to Gemini format.
 *
 * The outcome is written to the debug log. Any error raised by the API call
 * propagates to the caller as a rejected promise.
 *
 * @param content The Anthropic-format content blocks to count tokens for
 * @returns A promise resolving to the `totalTokens` value reported by the API,
 *          or 0 when the response carries no `totalTokens`
 */
private async countTokensAsync(content: Array<Anthropic.Messages.ContentBlockParam>): Promise<number> {
	const model = this.getModel().id

	const { totalTokens } = await this.client.models.countTokens({
		model,
		contents: convertAnthropicContentToGemini(content),
	})

	if (totalTokens !== undefined) {
		// Only logged for now; a later change could cache it or feed telemetry
		console.debug(`Gemini token count: API=${totalTokens}`)
		return totalTokens
	}

	// The response had no count to report
	console.debug("Gemini token counting returned undefined")
	return 0
}

public calculateCost({
Expand Down