diff --git a/src/api/providers/__tests__/unbound.test.ts b/src/api/providers/__tests__/unbound.test.ts
index 7d11e6daad9..3a0fe4868ed 100644
--- a/src/api/providers/__tests__/unbound.test.ts
+++ b/src/api/providers/__tests__/unbound.test.ts
@@ -1,6 +1,5 @@
 import { UnboundHandler } from "../unbound"
 import { ApiHandlerOptions } from "../../../shared/api"
-import OpenAI from "openai"
 import { Anthropic } from "@anthropic-ai/sdk"
 
 // Mock OpenAI client
@@ -16,6 +15,7 @@ jest.mock("openai", () => {
 				create: (...args: any[]) => {
 					const stream = {
 						[Symbol.asyncIterator]: async function* () {
+							// First chunk with content
 							yield {
 								choices: [
 									{
@@ -24,13 +24,25 @@
 									},
 								],
 							}
+							// Second chunk with usage data
 							yield {
-								choices: [
-									{
-										delta: {},
-										index: 0,
-									},
-								],
+								choices: [{ delta: {}, index: 0 }],
+								usage: {
+									prompt_tokens: 10,
+									completion_tokens: 5,
+									total_tokens: 15,
+								},
+							}
+							// Third chunk with cache usage data
+							yield {
+								choices: [{ delta: {}, index: 0 }],
+								usage: {
+									prompt_tokens: 8,
+									completion_tokens: 4,
+									total_tokens: 12,
+									cache_creation_input_tokens: 3,
+									cache_read_input_tokens: 2,
+								},
 							}
 						},
 					}
@@ -95,19 +107,37 @@ describe("UnboundHandler", () => {
 		},
 	]
 
-	it("should handle streaming responses", async () => {
+	it("should handle streaming responses with text and usage data", async () => {
 		const stream = handler.createMessage(systemPrompt, messages)
-		const chunks: any[] = []
+		const chunks: Array<{ type: string } & Record<string, any>> = []
 		for await (const chunk of stream) {
 			chunks.push(chunk)
 		}
 
-		expect(chunks.length).toBe(1)
+		expect(chunks.length).toBe(3)
+
+		// Verify text chunk
 		expect(chunks[0]).toEqual({
 			type: "text",
 			text: "Test response",
 		})
 
+		// Verify regular usage data
+		expect(chunks[1]).toEqual({
+			type: "usage",
+			inputTokens: 10,
+			outputTokens: 5,
+		})
+
+		// Verify usage data with cache information
+		expect(chunks[2]).toEqual({
+			type: "usage",
+			inputTokens: 8,
+			outputTokens: 4,
+			cacheWriteTokens: 3,
+			cacheReadTokens: 2,
+		})
+
 		expect(mockCreate).toHaveBeenCalledWith(
 			expect.objectContaining({
 				model: "claude-3-5-sonnet-20241022",
diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts
index 23e419c0b1d..2bc3d82822a 100644
--- a/src/api/providers/unbound.ts
+++ b/src/api/providers/unbound.ts
@@ -3,7 +3,12 @@ import OpenAI from "openai"
 import { ApiHandler, SingleCompletionHandler } from "../"
 import { ApiHandlerOptions, ModelInfo, UnboundModelId, unboundDefaultModelId, unboundModels } from "../../shared/api"
 import { convertToOpenAiMessages } from "../transform/openai-format"
-import { ApiStream } from "../transform/stream"
+import { ApiStream, ApiStreamUsageChunk } from "../transform/stream"
+
+interface UnboundUsage extends OpenAI.CompletionUsage {
+	cache_creation_input_tokens?: number
+	cache_read_input_tokens?: number
+}
 
 export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 	private options: ApiHandlerOptions
@@ -96,7 +101,7 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 
 		for await (const chunk of completion) {
 			const delta = chunk.choices[0]?.delta
-			const usage = chunk.usage
+			const usage = chunk.usage as UnboundUsage
 
 			if (delta?.content) {
 				yield {
@@ -106,11 +111,21 @@ export class UnboundHandler implements ApiHandler, SingleCompletionHandler {
 			}
 
 			if (usage) {
-				yield {
+				const usageData: ApiStreamUsageChunk = {
 					type: "usage",
-					inputTokens: usage?.prompt_tokens || 0,
-					outputTokens: usage?.completion_tokens || 0,
+					inputTokens: usage.prompt_tokens || 0,
+					outputTokens: usage.completion_tokens || 0,
 				}
+
+				// Only add cache tokens if they exist
+				if (usage.cache_creation_input_tokens) {
+					usageData.cacheWriteTokens = usage.cache_creation_input_tokens
+				}
+				if (usage.cache_read_input_tokens) {
+					usageData.cacheReadTokens = usage.cache_read_input_tokens
+				}
+
+				yield usageData
 			}
 		}
 	}
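Note for reviewers: the usage-mapping change above is easiest to evaluate in isolation, so here is a minimal standalone sketch of the same logic under some assumptions. The `toUsageChunk` helper name is hypothetical (the real code inlines this inside createMessage), and both interfaces are simplified stand-ins re-declared here so the snippet compiles on its own; the real `ApiStreamUsageChunk` lives in ../transform/stream and `UnboundUsage` extends OpenAI.CompletionUsage as in the diff.

// Simplified stand-in for the Unbound usage payload; the optional
// Anthropic-style cache fields are absent for models that do not report them.
interface UnboundUsage {
	prompt_tokens: number
	completion_tokens: number
	total_tokens: number
	cache_creation_input_tokens?: number
	cache_read_input_tokens?: number
}

// Simplified stand-in for ApiStreamUsageChunk, inferred from the test expectations.
interface ApiStreamUsageChunk {
	type: "usage"
	inputTokens: number
	outputTokens: number
	cacheWriteTokens?: number
	cacheReadTokens?: number
}

// Hypothetical helper mirroring the new yield logic: cache fields are
// attached only when present and non-zero, so downstream consumers can
// distinguish "no cache info" from "cache info of zero" is not required here.
function toUsageChunk(usage: UnboundUsage): ApiStreamUsageChunk {
	const chunk: ApiStreamUsageChunk = {
		type: "usage",
		inputTokens: usage.prompt_tokens || 0,
		outputTokens: usage.completion_tokens || 0,
	}
	if (usage.cache_creation_input_tokens) {
		chunk.cacheWriteTokens = usage.cache_creation_input_tokens
	}
	if (usage.cache_read_input_tokens) {
		chunk.cacheReadTokens = usage.cache_read_input_tokens
	}
	return chunk
}

// Example matching the third mocked chunk in the test:
// toUsageChunk({ prompt_tokens: 8, completion_tokens: 4, total_tokens: 12,
//                cache_creation_input_tokens: 3, cache_read_input_tokens: 2 })
// => { type: "usage", inputTokens: 8, outputTokens: 4, cacheWriteTokens: 3, cacheReadTokens: 2 }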