diff --git a/packages/types/src/message.ts b/packages/types/src/message.ts index 49d5203c664..0c87655fc0c 100644 --- a/packages/types/src/message.ts +++ b/packages/types/src/message.ts @@ -155,6 +155,7 @@ export const clineMessageSchema = z.object({ progressStatus: toolProgressStatusSchema.optional(), contextCondense: contextCondenseSchema.optional(), isProtected: z.boolean().optional(), + apiProtocol: z.union([z.literal("openai"), z.literal("anthropic")]).optional(), }) export type ClineMessage = z.infer diff --git a/packages/types/src/provider-settings.ts b/packages/types/src/provider-settings.ts index e940ececd1d..0d6dd613070 100644 --- a/packages/types/src/provider-settings.ts +++ b/packages/types/src/provider-settings.ts @@ -292,3 +292,11 @@ export const getModelId = (settings: ProviderSettings): string | undefined => { const modelIdKey = MODEL_ID_KEYS.find((key) => settings[key]) return modelIdKey ? (settings[modelIdKey] as string) : undefined } + +// Providers that use Anthropic-style API protocol +export const ANTHROPIC_STYLE_PROVIDERS: ProviderName[] = ["anthropic", "claude-code"] + +// Helper function to determine API protocol for a provider +export const getApiProtocol = (provider: ProviderName | undefined): "anthropic" | "openai" => { + return provider && ANTHROPIC_STYLE_PROVIDERS.includes(provider) ? "anthropic" : "openai" +} diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 31260cd6fa2..c8553a8fc67 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -21,6 +21,7 @@ import { type HistoryItem, TelemetryEventName, TodoItem, + getApiProtocol, } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" import { CloudService } from "@roo-code/cloud" @@ -1207,11 +1208,16 @@ export class Task extends EventEmitter { // top-down build file structure of project which for large projects can // take a few seconds. For the best UX we show a placeholder api_req_started // message with a loading spinner as this happens. + + // Determine API protocol based on provider + const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider) + await this.say( "api_req_started", JSON.stringify({ request: userContent.map((block) => formatContentBlockToMarkdown(block)).join("\n\n") + "\n\nLoading...", + apiProtocol, }), ) @@ -1243,6 +1249,7 @@ export class Task extends EventEmitter { this.clineMessages[lastApiReqIndex].text = JSON.stringify({ request: finalUserContent.map((block) => formatContentBlockToMarkdown(block)).join("\n\n"), + apiProtocol, } satisfies ClineApiReqInfo) await this.saveClineMessages() @@ -1263,8 +1270,9 @@ export class Task extends EventEmitter { // of prices in tasks from history (it's worth removing a few months // from now). const updateApiReqMsg = (cancelReason?: ClineApiReqCancelReason, streamingFailedMessage?: string) => { + const existingData = JSON.parse(this.clineMessages[lastApiReqIndex].text || "{}") this.clineMessages[lastApiReqIndex].text = JSON.stringify({ - ...JSON.parse(this.clineMessages[lastApiReqIndex].text || "{}"), + ...existingData, tokensIn: inputTokens, tokensOut: outputTokens, cacheWrites: cacheWriteTokens, diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index 953c0c1070e..8aead4674a5 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -379,6 +379,7 @@ export interface ClineApiReqInfo { cost?: number cancelReason?: ClineApiReqCancelReason streamingFailedMessage?: string + apiProtocol?: "anthropic" | "openai" } export type ClineApiReqCancelReason = "streaming_failed" | "user_cancelled" diff --git a/src/shared/__tests__/getApiMetrics.spec.ts b/src/shared/__tests__/getApiMetrics.spec.ts index a1b1eecaed8..02f45c5cc4c 100644 --- a/src/shared/__tests__/getApiMetrics.spec.ts +++ b/src/shared/__tests__/getApiMetrics.spec.ts @@ -61,7 +61,7 @@ describe("getApiMetrics", () => { expect(result.totalCacheWrites).toBe(5) expect(result.totalCacheReads).toBe(10) expect(result.totalCost).toBe(0.005) - expect(result.contextTokens).toBe(315) // 100 + 200 + 5 + 10 + expect(result.contextTokens).toBe(300) // 100 + 200 (OpenAI default, no cache tokens) }) it("should calculate metrics from multiple api_req_started messages", () => { @@ -83,7 +83,7 @@ describe("getApiMetrics", () => { expect(result.totalCacheWrites).toBe(8) // 5 + 3 expect(result.totalCacheReads).toBe(17) // 10 + 7 expect(result.totalCost).toBe(0.008) // 0.005 + 0.003 - expect(result.contextTokens).toBe(210) // 50 + 150 + 3 + 7 (from the last message) + expect(result.contextTokens).toBe(200) // 50 + 150 (OpenAI default, no cache tokens) }) it("should calculate metrics from condense_context messages", () => { @@ -123,7 +123,7 @@ describe("getApiMetrics", () => { expect(result.totalCacheWrites).toBe(8) // 5 + 3 expect(result.totalCacheReads).toBe(17) // 10 + 7 expect(result.totalCost).toBe(0.01) // 0.005 + 0.002 + 0.003 - expect(result.contextTokens).toBe(210) // 50 + 150 + 3 + 7 (from the last api_req_started message) + expect(result.contextTokens).toBe(200) // 50 + 150 (OpenAI default, no cache tokens) }) }) @@ -242,9 +242,9 @@ describe("getApiMetrics", () => { expect(result.totalCacheReads).toBe(10) expect(result.totalCost).toBe(0.005) - // The implementation will use the last message with tokens for contextTokens - // In this case, it's the cacheReads message - expect(result.contextTokens).toBe(10) + // The implementation will use the last message that has any tokens + // In this case, it's the message with tokensOut:200 (since the last few messages have no tokensIn/Out) + expect(result.contextTokens).toBe(200) // 0 + 200 (from the tokensOut message) }) it("should handle non-number values in api_req_started message", () => { @@ -264,8 +264,8 @@ describe("getApiMetrics", () => { expect(result.totalCacheReads).toBeUndefined() expect(result.totalCost).toBe(0) - // The implementation concatenates string values for contextTokens - expect(result.contextTokens).toBe("not-a-numbernot-a-numbernot-a-numbernot-a-number") + // The implementation concatenates all token values including cache tokens + expect(result.contextTokens).toBe("not-a-numbernot-a-number") // tokensIn + tokensOut (OpenAI default) }) }) @@ -279,7 +279,7 @@ describe("getApiMetrics", () => { const result = getApiMetrics(messages) // Should use the values from the last api_req_started message - expect(result.contextTokens).toBe(210) // 50 + 150 + 3 + 7 + expect(result.contextTokens).toBe(200) // 50 + 150 (OpenAI default, no cache tokens) }) it("should calculate contextTokens from the last condense_context message", () => { @@ -305,7 +305,7 @@ describe("getApiMetrics", () => { const result = getApiMetrics(messages) // Should use the values from the last api_req_started message - expect(result.contextTokens).toBe(210) // 50 + 150 + 3 + 7 + expect(result.contextTokens).toBe(200) // 50 + 150 (OpenAI default, no cache tokens) }) it("should handle missing values when calculating contextTokens", () => { @@ -320,7 +320,7 @@ describe("getApiMetrics", () => { const result = getApiMetrics(messages) // Should handle missing or invalid values - expect(result.contextTokens).toBe(15) // 0 + 0 + 5 + 10 + expect(result.contextTokens).toBe(0) // 0 + 0 (OpenAI default, no cache tokens) // Restore console.error console.error = originalConsoleError diff --git a/src/shared/getApiMetrics.ts b/src/shared/getApiMetrics.ts index 49476fdbb6d..dcd9ae9efe6 100644 --- a/src/shared/getApiMetrics.ts +++ b/src/shared/getApiMetrics.ts @@ -6,6 +6,7 @@ export type ParsedApiReqStartedTextType = { cacheWrites: number cacheReads: number cost?: number // Only present if combineApiRequests has been called + apiProtocol?: "anthropic" | "openai" } /** @@ -72,8 +73,15 @@ export function getApiMetrics(messages: ClineMessage[]) { if (message.type === "say" && message.say === "api_req_started" && message.text) { try { const parsedText: ParsedApiReqStartedTextType = JSON.parse(message.text) - const { tokensIn, tokensOut, cacheWrites, cacheReads } = parsedText - result.contextTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0) + const { tokensIn, tokensOut, cacheWrites, cacheReads, apiProtocol } = parsedText + + // Calculate context tokens based on API protocol + if (apiProtocol === "anthropic") { + result.contextTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0) + } else { + // For OpenAI (or when protocol is not specified) + result.contextTokens = (tokensIn || 0) + (tokensOut || 0) + } } catch (error) { console.error("Error parsing JSON:", error) continue