
Commit 416fa57

Fix cost and token tracking between provider styles (#8954)
1 parent 4a096e1 commit 416fa57

11 files changed: +180, -94 lines


src/api/providers/anthropic.ts

Lines changed: 9 additions & 7 deletions
```diff
@@ -230,17 +230,19 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHandler
 		}

 		if (inputTokens > 0 || outputTokens > 0 || cacheWriteTokens > 0 || cacheReadTokens > 0) {
+			const { totalCost } = calculateApiCostAnthropic(
+				this.getModel().info,
+				inputTokens,
+				outputTokens,
+				cacheWriteTokens,
+				cacheReadTokens,
+			)
+
 			yield {
 				type: "usage",
 				inputTokens: 0,
 				outputTokens: 0,
-				totalCost: calculateApiCostAnthropic(
-					this.getModel().info,
-					inputTokens,
-					outputTokens,
-					cacheWriteTokens,
-					cacheReadTokens,
-				),
+				totalCost,
 			}
 		}
 	}
```

src/api/providers/cerebras.ts

Lines changed: 2 additions & 1 deletion
```diff
@@ -331,6 +331,7 @@ export class CerebrasHandler extends BaseProvider implements SingleCompletionHandler
 		const { info } = this.getModel()
 		// Use actual token usage from the last request
 		const { inputTokens, outputTokens } = this.lastUsage
-		return calculateApiCostOpenAI(info, inputTokens, outputTokens)
+		const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens)
+		return totalCost
 	}
 }
```

src/api/providers/deepinfra.ts

Lines changed: 2 additions & 2 deletions
```diff
@@ -131,9 +131,9 @@ export class DeepInfraHandler extends RouterProvider implements SingleCompletionHandler
 		const cacheWriteTokens = usage?.prompt_tokens_details?.cache_write_tokens || 0
 		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0

-		const totalCost = modelInfo
+		const { totalCost } = modelInfo
 			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
-			: 0
+			: { totalCost: 0 }

 		return {
 			type: "usage",
```

src/api/providers/groq.ts

Lines changed: 1 addition & 1 deletion
```diff
@@ -64,7 +64,7 @@ export class GroqHandler extends BaseOpenAiCompatibleProvider<GroqModelId> {
 		const cacheWriteTokens = 0

 		// Calculate cost using OpenAI-compatible cost calculation
-		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+		const { totalCost } = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)

 		yield {
 			type: "usage",
```

src/api/providers/lite-llm.ts

Lines changed: 9 additions & 8 deletions
```diff
@@ -165,22 +165,23 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHandler
 					(lastUsage as any).prompt_cache_hit_tokens ||
 					0

+				const { totalCost } = calculateApiCostOpenAI(
+					info,
+					lastUsage.prompt_tokens || 0,
+					lastUsage.completion_tokens || 0,
+					cacheWriteTokens,
+					cacheReadTokens,
+				)
+
 				const usageData: ApiStreamUsageChunk = {
 					type: "usage",
 					inputTokens: lastUsage.prompt_tokens || 0,
 					outputTokens: lastUsage.completion_tokens || 0,
 					cacheWriteTokens: cacheWriteTokens > 0 ? cacheWriteTokens : undefined,
 					cacheReadTokens: cacheReadTokens > 0 ? cacheReadTokens : undefined,
+					totalCost,
 				}

-				usageData.totalCost = calculateApiCostOpenAI(
-					info,
-					usageData.inputTokens,
-					usageData.outputTokens,
-					usageData.cacheWriteTokens || 0,
-					usageData.cacheReadTokens || 0,
-				)
-
 				yield usageData
 			}
 		} catch (error) {
```

src/api/providers/openai-native.ts

Lines changed: 2 additions & 2 deletions
```diff
@@ -99,8 +99,8 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler
 		const effectiveInfo = this.applyServiceTierPricing(model.info, effectiveTier)

 		// Pass total input tokens directly to calculateApiCostOpenAI
-		// The function handles subtracting both cache reads and writes internally (see shared/cost.ts:46)
-		const totalCost = calculateApiCostOpenAI(
+		// The function handles subtracting both cache reads and writes internally
+		const { totalCost } = calculateApiCostOpenAI(
 			effectiveInfo,
 			totalInputTokens,
 			totalOutputTokens,
```

src/api/providers/requesty.ts

Lines changed: 2 additions & 2 deletions
```diff
@@ -85,9 +85,9 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHandler
 		const outputTokens = requestyUsage?.completion_tokens || 0
 		const cacheWriteTokens = requestyUsage?.prompt_tokens_details?.caching_tokens || 0
 		const cacheReadTokens = requestyUsage?.prompt_tokens_details?.cached_tokens || 0
-		const totalCost = modelInfo
+		const { totalCost } = modelInfo
 			? calculateApiCostOpenAI(modelInfo, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
-			: 0
+			: { totalCost: 0 }

 		return {
 			type: "usage",
```

src/core/task/Task.ts

Lines changed: 51 additions & 24 deletions
```diff
@@ -74,7 +74,7 @@ import { RooTerminalProcess } from "../../integrations/terminal/types"
 import { TerminalRegistry } from "../../integrations/terminal/TerminalRegistry"

 // utils
-import { calculateApiCostAnthropic } from "../../shared/cost"
+import { calculateApiCostAnthropic, calculateApiCostOpenAI } from "../../shared/cost"
 import { getWorkspacePath } from "../../utils/path"

 // prompts
@@ -1886,21 +1886,35 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 		}

 		const existingData = JSON.parse(this.clineMessages[lastApiReqIndex].text || "{}")
+
+		// Calculate total tokens and cost using provider-aware function
+		const modelId = getModelId(this.apiConfiguration)
+		const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider, modelId)
+
+		const costResult =
+			apiProtocol === "anthropic"
+				? calculateApiCostAnthropic(
+						this.api.getModel().info,
+						inputTokens,
+						outputTokens,
+						cacheWriteTokens,
+						cacheReadTokens,
+					)
+				: calculateApiCostOpenAI(
+						this.api.getModel().info,
+						inputTokens,
+						outputTokens,
+						cacheWriteTokens,
+						cacheReadTokens,
+					)
+
 		this.clineMessages[lastApiReqIndex].text = JSON.stringify({
 			...existingData,
-			tokensIn: inputTokens,
-			tokensOut: outputTokens,
+			tokensIn: costResult.totalInputTokens,
+			tokensOut: costResult.totalOutputTokens,
 			cacheWrites: cacheWriteTokens,
 			cacheReads: cacheReadTokens,
-			cost:
-				totalCost ??
-				calculateApiCostAnthropic(
-					this.api.getModel().info,
-					inputTokens,
-					outputTokens,
-					cacheWriteTokens,
-					cacheReadTokens,
-				),
+			cost: totalCost ?? costResult.totalCost,
 			cancelReason,
 			streamingFailedMessage,
 		} satisfies ClineApiReqInfo)
@@ -2104,21 +2118,34 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
 			await this.updateClineMessage(apiReqMessage)
 		}

-		// Capture telemetry
+		// Capture telemetry with provider-aware cost calculation
+		const modelId = getModelId(this.apiConfiguration)
+		const apiProtocol = getApiProtocol(this.apiConfiguration.apiProvider, modelId)
+
+		// Use the appropriate cost function based on the API protocol
+		const costResult =
+			apiProtocol === "anthropic"
+				? calculateApiCostAnthropic(
+						this.api.getModel().info,
+						tokens.input,
+						tokens.output,
+						tokens.cacheWrite,
+						tokens.cacheRead,
+					)
+				: calculateApiCostOpenAI(
+						this.api.getModel().info,
+						tokens.input,
+						tokens.output,
+						tokens.cacheWrite,
+						tokens.cacheRead,
+					)
+
 		TelemetryService.instance.captureLlmCompletion(this.taskId, {
-			inputTokens: tokens.input,
-			outputTokens: tokens.output,
+			inputTokens: costResult.totalInputTokens,
+			outputTokens: costResult.totalOutputTokens,
 			cacheWriteTokens: tokens.cacheWrite,
 			cacheReadTokens: tokens.cacheRead,
-			cost:
-				tokens.total ??
-				calculateApiCostAnthropic(
-					this.api.getModel().info,
-					tokens.input,
-					tokens.output,
-					tokens.cacheWrite,
-					tokens.cacheRead,
-				),
+			cost: tokens.total ?? costResult.totalCost,
 		})
 	}
 }
```

src/shared/cost.ts

Lines changed: 32 additions & 6 deletions
```diff
@@ -1,18 +1,31 @@
 import type { ModelInfo } from "@roo-code/types"

+export interface ApiCostResult {
+	totalInputTokens: number
+	totalOutputTokens: number
+	totalCost: number
+}
+
 function calculateApiCostInternal(
 	modelInfo: ModelInfo,
 	inputTokens: number,
 	outputTokens: number,
 	cacheCreationInputTokens: number,
 	cacheReadInputTokens: number,
-): number {
+	totalInputTokens: number,
+	totalOutputTokens: number,
+): ApiCostResult {
 	const cacheWritesCost = ((modelInfo.cacheWritesPrice || 0) / 1_000_000) * cacheCreationInputTokens
 	const cacheReadsCost = ((modelInfo.cacheReadsPrice || 0) / 1_000_000) * cacheReadInputTokens
 	const baseInputCost = ((modelInfo.inputPrice || 0) / 1_000_000) * inputTokens
 	const outputCost = ((modelInfo.outputPrice || 0) / 1_000_000) * outputTokens
 	const totalCost = cacheWritesCost + cacheReadsCost + baseInputCost + outputCost
-	return totalCost
+
+	return {
+		totalInputTokens,
+		totalOutputTokens,
+		totalCost,
+	}
 }

 // For Anthropic compliant usage, the input tokens count does NOT include the
@@ -23,13 +36,22 @@ export function calculateApiCostAnthropic(
 	outputTokens: number,
 	cacheCreationInputTokens?: number,
 	cacheReadInputTokens?: number,
-): number {
+): ApiCostResult {
+	const cacheCreation = cacheCreationInputTokens || 0
+	const cacheRead = cacheReadInputTokens || 0
+
+	// For Anthropic: inputTokens does NOT include cached tokens
+	// Total input = base input + cache creation + cache reads
+	const totalInputTokens = inputTokens + cacheCreation + cacheRead
+
 	return calculateApiCostInternal(
 		modelInfo,
 		inputTokens,
 		outputTokens,
-		cacheCreationInputTokens || 0,
-		cacheReadInputTokens || 0,
+		cacheCreation,
+		cacheRead,
+		totalInputTokens,
+		outputTokens,
 	)
 }

@@ -40,17 +62,21 @@ export function calculateApiCostOpenAI(
 	outputTokens: number,
 	cacheCreationInputTokens?: number,
 	cacheReadInputTokens?: number,
-): number {
+): ApiCostResult {
 	const cacheCreationInputTokensNum = cacheCreationInputTokens || 0
 	const cacheReadInputTokensNum = cacheReadInputTokens || 0
 	const nonCachedInputTokens = Math.max(0, inputTokens - cacheCreationInputTokensNum - cacheReadInputTokensNum)

+	// For OpenAI: inputTokens ALREADY includes all tokens (cached + non-cached)
+	// So we pass the original inputTokens as the total
 	return calculateApiCostInternal(
 		modelInfo,
 		nonCachedInputTokens,
 		outputTokens,
 		cacheCreationInputTokensNum,
 		cacheReadInputTokensNum,
+		inputTokens,
+		outputTokens,
 	)
 }
```
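To make the new ApiCostResult return shape concrete, here is a worked example with made-up prices and token counts (not part of the commit; the import path is illustrative). The same request reported Anthropic-style (cached tokens excluded from inputTokens) and OpenAI-style (cached tokens included) should produce the same totalInputTokens and totalCost.

```typescript
import type { ModelInfo } from "@roo-code/types"
import { calculateApiCostAnthropic, calculateApiCostOpenAI } from "./cost"

// Made-up pricing per million tokens; a real ModelInfo has more required fields, hence the cast.
const info = { inputPrice: 3, outputPrice: 15, cacheWritesPrice: 3.75, cacheReadsPrice: 0.3 } as unknown as ModelInfo

// Anthropic-style usage: 100 non-cached input tokens, 50 output, 2,000 cache writes, 10,000 cache reads.
const anthropic = calculateApiCostAnthropic(info, 100, 50, 2_000, 10_000)
// anthropic.totalInputTokens === 100 + 2_000 + 10_000 === 12_100
// anthropic.totalCost === (100 * 3 + 50 * 15 + 2_000 * 3.75 + 10_000 * 0.3) / 1_000_000 === 0.01155

// OpenAI-style usage for the same request: inputTokens already includes the cached tokens.
const openai = calculateApiCostOpenAI(info, 12_100, 50, 2_000, 10_000)
// openai.totalInputTokens === 12_100 (passed through unchanged)
// Non-cached input is max(0, 12_100 - 2_000 - 10_000) === 100, so openai.totalCost === anthropic.totalCost
```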

src/shared/getApiMetrics.ts

Lines changed: 6 additions & 9 deletions
```diff
@@ -80,15 +80,12 @@ export function getApiMetrics(messages: ClineMessage[]) {
 		if (message.type === "say" && message.say === "api_req_started" && message.text) {
 			try {
 				const parsedText: ParsedApiReqStartedTextType = JSON.parse(message.text)
-				const { tokensIn, tokensOut, cacheWrites, cacheReads, apiProtocol } = parsedText
-
-				// Calculate context tokens based on API protocol.
-				if (apiProtocol === "anthropic") {
-					result.contextTokens = (tokensIn || 0) + (tokensOut || 0) + (cacheWrites || 0) + (cacheReads || 0)
-				} else {
-					// For OpenAI (or when protocol is not specified).
-					result.contextTokens = (tokensIn || 0) + (tokensOut || 0)
-				}
+				const { tokensIn, tokensOut } = parsedText
+
+				// Since tokensIn now stores TOTAL input tokens (including cache tokens),
+				// we no longer need to add cacheWrites and cacheReads separately.
+				// This applies to both Anthropic and OpenAI protocols.
+				result.contextTokens = (tokensIn || 0) + (tokensOut || 0)
 			} catch (error) {
 				console.error("Error parsing JSON:", error)
 				continue
```
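With illustrative numbers (made up, not from the commit), the simplified calculation reads the totals that Task.ts now writes, so cache tokens are counted exactly once regardless of protocol:

```typescript
// Illustrative api_req_started payload as written by the updated Task.ts,
// where tokensIn already includes cache write/read tokens.
const parsedText = { tokensIn: 12_100, tokensOut: 50, cacheWrites: 2_000, cacheReads: 10_000 }

// Applying the old Anthropic branch to this payload would give
// 12_100 + 50 + 2_000 + 10_000 = 24_150, double-counting the cache tokens.
// The protocol-independent calculation above yields:
const contextTokens = (parsedText.tokensIn || 0) + (parsedText.tokensOut || 0) // 12_150
```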
