From f306461276fbad2da4a5d693f970b2dc0b478798 Mon Sep 17 00:00:00 2001
From: Matt Rubens
Date: Mon, 10 Mar 2025 22:53:40 -0400
Subject: [PATCH] Fix usage tracking for SiliconFlow etc

---
 .changeset/tidy-queens-pay.md               |   5 +
 .../__tests__/openai-usage-tracking.test.ts | 235 ++++++++++++++++++
 src/api/providers/openai.ts                 |   8 +-
 3 files changed, 247 insertions(+), 1 deletion(-)
 create mode 100644 .changeset/tidy-queens-pay.md
 create mode 100644 src/api/providers/__tests__/openai-usage-tracking.test.ts

diff --git a/.changeset/tidy-queens-pay.md b/.changeset/tidy-queens-pay.md
new file mode 100644
index 00000000000..750a58c7892
--- /dev/null
+++ b/.changeset/tidy-queens-pay.md
@@ -0,0 +1,5 @@
+---
+"roo-cline": patch
+---
+
+Fix usage tracking for SiliconFlow etc
diff --git a/src/api/providers/__tests__/openai-usage-tracking.test.ts b/src/api/providers/__tests__/openai-usage-tracking.test.ts
new file mode 100644
index 00000000000..6df9a0bca50
--- /dev/null
+++ b/src/api/providers/__tests__/openai-usage-tracking.test.ts
@@ -0,0 +1,235 @@
+import { OpenAiHandler } from "../openai"
+import { ApiHandlerOptions } from "../../../shared/api"
+import { Anthropic } from "@anthropic-ai/sdk"
+
+// Mock OpenAI client with multiple chunks that contain usage data
+const mockCreate = jest.fn()
+jest.mock("openai", () => {
+	return {
+		__esModule: true,
+		default: jest.fn().mockImplementation(() => ({
+			chat: {
+				completions: {
+					create: mockCreate.mockImplementation(async (options) => {
+						if (!options.stream) {
+							return {
+								id: "test-completion",
+								choices: [
+									{
+										message: { role: "assistant", content: "Test response", refusal: null },
+										finish_reason: "stop",
+										index: 0,
+									},
+								],
+								usage: {
+									prompt_tokens: 10,
+									completion_tokens: 5,
+									total_tokens: 15,
+								},
+							}
+						}
+
+						// Return a stream with multiple chunks that include usage metrics
+						return {
+							[Symbol.asyncIterator]: async function* () {
+								// First chunk with partial usage
+								yield {
+									choices: [
+										{
+											delta: { content: "Test " },
+											index: 0,
+										},
+									],
+									usage: {
+										prompt_tokens: 10,
+										completion_tokens: 2,
+										total_tokens: 12,
+									},
+								}
+
+								// Second chunk with updated usage
+								yield {
+									choices: [
+										{
+											delta: { content: "response" },
+											index: 0,
+										},
+									],
+									usage: {
+										prompt_tokens: 10,
+										completion_tokens: 4,
+										total_tokens: 14,
+									},
+								}
+
+								// Final chunk with complete usage
+								yield {
+									choices: [
+										{
+											delta: {},
+											index: 0,
+										},
+									],
+									usage: {
+										prompt_tokens: 10,
+										completion_tokens: 5,
+										total_tokens: 15,
+									},
+								}
+							},
+						}
+					}),
+				},
+			},
+		})),
+	}
+})
+
+describe("OpenAiHandler with usage tracking fix", () => {
+	let handler: OpenAiHandler
+	let mockOptions: ApiHandlerOptions
+
+	beforeEach(() => {
+		mockOptions = {
+			openAiApiKey: "test-api-key",
+			openAiModelId: "gpt-4",
+			openAiBaseUrl: "https://api.openai.com/v1",
+		}
+		handler = new OpenAiHandler(mockOptions)
+		mockCreate.mockClear()
+	})
+
+	describe("usage metrics with streaming", () => {
+		const systemPrompt = "You are a helpful assistant."
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{
+				role: "user",
+				content: [
+					{
+						type: "text" as const,
+						text: "Hello!",
+					},
+				],
+			},
+		]
+
+		it("should only yield usage metrics once at the end of the stream", async () => {
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Check we have text chunks
+			const textChunks = chunks.filter((chunk) => chunk.type === "text")
+			expect(textChunks).toHaveLength(2)
+			expect(textChunks[0].text).toBe("Test ")
+			expect(textChunks[1].text).toBe("response")
+
+			// Check we only have one usage chunk and it's the last one
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks).toHaveLength(1)
+			expect(usageChunks[0]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+
+			// Check the usage chunk is the last one reported from the API
+			const lastChunk = chunks[chunks.length - 1]
+			expect(lastChunk.type).toBe("usage")
+			expect(lastChunk.inputTokens).toBe(10)
+			expect(lastChunk.outputTokens).toBe(5)
+		})
+
+		it("should handle case where usage is only in the final chunk", async () => {
+			// Override the mock for this specific test
+			mockCreate.mockImplementationOnce(async (options) => {
+				if (!options.stream) {
+					return {
+						id: "test-completion",
+						choices: [{ message: { role: "assistant", content: "Test response" } }],
+						usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
+					}
+				}
+
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						// First chunk with no usage
+						yield {
+							choices: [{ delta: { content: "Test " }, index: 0 }],
+							usage: null,
+						}
+
+						// Second chunk with no usage
+						yield {
+							choices: [{ delta: { content: "response" }, index: 0 }],
+							usage: null,
+						}
+
+						// Final chunk with usage data
+						yield {
+							choices: [{ delta: {}, index: 0 }],
+							usage: {
+								prompt_tokens: 10,
+								completion_tokens: 5,
+								total_tokens: 15,
+							},
+						}
+					},
+				}
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Check usage metrics
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks).toHaveLength(1)
+			expect(usageChunks[0]).toEqual({
+				type: "usage",
+				inputTokens: 10,
+				outputTokens: 5,
+			})
+		})
+
+		it("should handle case where no usage is provided", async () => {
+			// Override the mock for this specific test
+			mockCreate.mockImplementationOnce(async (options) => {
+				if (!options.stream) {
+					return {
+						id: "test-completion",
+						choices: [{ message: { role: "assistant", content: "Test response" } }],
+						usage: null,
+					}
+				}
+
+				return {
+					[Symbol.asyncIterator]: async function* () {
+						yield {
+							choices: [{ delta: { content: "Test response" }, index: 0 }],
+							usage: null,
+						}
+						yield {
+							choices: [{ delta: {}, index: 0 }],
+							usage: null,
+						}
+					},
+				}
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			// Check we don't have any usage chunks
+			const usageChunks = chunks.filter((chunk) => chunk.type === "usage")
+			expect(usageChunks).toHaveLength(0)
+		})
+	})
+})
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index 2af3f2da05a..a6e3eb14881 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -99,6 +99,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 
 			const stream = await this.client.chat.completions.create(requestOptions)
 
+			let lastUsage
+
 			for await (const chunk of stream) {
 				const delta = chunk.choices[0]?.delta ?? {}
 
@@ -116,9 +118,13 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 					}
 				}
 				if (chunk.usage) {
-					yield this.processUsageMetrics(chunk.usage, modelInfo)
+					lastUsage = chunk.usage
 				}
 			}
+
+			if (lastUsage) {
+				yield this.processUsageMetrics(lastUsage, modelInfo)
+			}
 		} else {
 			// o1 for instance doesnt support streaming, non-1 temp, or system prompt
 			const systemMessage: OpenAI.Chat.ChatCompletionUserMessageParam = {
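
The idea behind the change, distilled: OpenAI-compatible providers such as SiliconFlow attach a cumulative `usage` object to every streamed chunk rather than only to the final one (this is exactly what the new test's mock simulates), so yielding metrics whenever `chunk.usage` was present could report the same request's tokens several times. The handler now just remembers the most recent snapshot and yields it once after the stream ends. The sketch below illustrates that pattern in isolation; the `trackUsage` generator and the chunk shapes are simplified stand-ins invented for this note, not the repo's actual `ApiStream` types.

// Simplified stand-ins for illustration (assumed shapes, not the repo's real types)
interface Usage {
	prompt_tokens: number
	completion_tokens: number
}

type ProviderChunk = { content?: string; usage?: Usage | null }
type ApiChunk =
	| { type: "text"; text: string }
	| { type: "usage"; inputTokens: number; outputTokens: number }

// Yield text as it streams in, but defer usage: keep overwriting lastUsage with the
// latest (cumulative) snapshot and emit a single usage chunk after the loop ends.
async function* trackUsage(stream: AsyncIterable<ProviderChunk>): AsyncGenerator<ApiChunk> {
	let lastUsage: Usage | undefined

	for await (const chunk of stream) {
		if (chunk.content) {
			yield { type: "text", text: chunk.content }
		}
		if (chunk.usage) {
			lastUsage = chunk.usage // overwrite, never accumulate
		}
	}

	if (lastUsage) {
		yield {
			type: "usage",
			inputTokens: lastUsage.prompt_tokens,
			outputTokens: lastUsage.completion_tokens,
		}
	}
}

// Example: a SiliconFlow-style stream that reports cumulative usage on every chunk.
async function* fakeStream(): AsyncGenerator<ProviderChunk> {
	yield { content: "Test ", usage: { prompt_tokens: 10, completion_tokens: 2 } }
	yield { content: "response", usage: { prompt_tokens: 10, completion_tokens: 5 } }
}

// Prints two text chunks, then exactly one usage chunk: { inputTokens: 10, outputTokens: 5 }.
;(async () => {
	for await (const chunk of trackUsage(fakeStream())) {
		console.log(chunk)
	}
})()

The new tests assert exactly this behavior: text chunks pass through as they arrive, at most one usage chunk is emitted, and it comes last.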