Commit 05ee49e

Track cache tokens and cost correctly for OpenAI
1 parent 678f6f2 commit 05ee49e

File tree

1 file changed: +56 −33 lines changed

src/api/providers/openai-native.ts

Lines changed: 56 additions & 33 deletions
@@ -11,9 +11,16 @@ import {
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
+import { calculateApiCostOpenAI } from "../../utils/cost"
 
 const OPENAI_NATIVE_DEFAULT_TEMPERATURE = 0
 
+// Define a type for the model object returned by getModel
+export type OpenAiNativeModel = {
+  id: OpenAiNativeModelId
+  info: ModelInfo
+}
+
 export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
   protected options: ApiHandlerOptions
   private client: OpenAI
@@ -26,31 +33,31 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
   }
 
   override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-    const modelId = this.getModel().id
+    const model = this.getModel()
 
-    if (modelId.startsWith("o1")) {
-      yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
+    if (model.id.startsWith("o1")) {
+      yield* this.handleO1FamilyMessage(model, systemPrompt, messages)
       return
     }
 
-    if (modelId.startsWith("o3-mini")) {
-      yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+    if (model.id.startsWith("o3-mini")) {
+      yield* this.handleO3FamilyMessage(model, systemPrompt, messages)
       return
     }
 
-    yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
+    yield* this.handleDefaultModelMessage(model, systemPrompt, messages)
   }
 
   private async *handleO1FamilyMessage(
-    modelId: string,
+    model: OpenAiNativeModel,
     systemPrompt: string,
     messages: Anthropic.Messages.MessageParam[],
   ): ApiStream {
     // o1 supports developer prompt with formatting
     // o1-preview and o1-mini only support user messages
-    const isOriginalO1 = modelId === "o1"
+    const isOriginalO1 = model.id === "o1"
     const response = await this.client.chat.completions.create({
-      model: modelId,
+      model: model.id,
       messages: [
         {
           role: isOriginalO1 ? "developer" : "user",
@@ -62,11 +69,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
       stream_options: { include_usage: true },
     })
 
-    yield* this.handleStreamResponse(response)
+    yield* this.handleStreamResponse(response, model)
   }
 
   private async *handleO3FamilyMessage(
-    modelId: string,
+    model: OpenAiNativeModel,
     systemPrompt: string,
     messages: Anthropic.Messages.MessageParam[],
   ): ApiStream {
@@ -84,23 +91,23 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
       reasoning_effort: this.getModel().info.reasoningEffort,
     })
 
-    yield* this.handleStreamResponse(stream)
+    yield* this.handleStreamResponse(stream, model)
   }
 
   private async *handleDefaultModelMessage(
-    modelId: string,
+    model: OpenAiNativeModel,
     systemPrompt: string,
     messages: Anthropic.Messages.MessageParam[],
   ): ApiStream {
     const stream = await this.client.chat.completions.create({
-      model: modelId,
+      model: model.id,
       temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
       messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
       stream: true,
       stream_options: { include_usage: true },
     })
 
-    yield* this.handleStreamResponse(stream)
+    yield* this.handleStreamResponse(stream, model)
   }
 
   private async *yieldResponseData(response: OpenAI.Chat.Completions.ChatCompletion): ApiStream {
@@ -115,7 +122,10 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
     }
   }
 
-  private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
+  private async *handleStreamResponse(
+    stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>,
+    model: OpenAiNativeModel,
+  ): ApiStream {
     for await (const chunk of stream) {
       const delta = chunk.choices[0]?.delta
       if (delta?.content) {
@@ -126,16 +136,29 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
       }
 
       if (chunk.usage) {
-        yield {
-          type: "usage",
-          inputTokens: chunk.usage.prompt_tokens || 0,
-          outputTokens: chunk.usage.completion_tokens || 0,
-        }
+        yield* this.yieldUsage(model.info, chunk.usage)
       }
     }
   }
 
-  override getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
+  private async *yieldUsage(info: ModelInfo, usage: OpenAI.Completions.CompletionUsage | undefined): ApiStream {
+    const inputTokens = usage?.prompt_tokens || 0 // sum of cache hits and misses
+    const outputTokens = usage?.completion_tokens || 0
+    const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+    const cacheWriteTokens = 0
+    const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+    const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
+    yield {
+      type: "usage",
+      inputTokens: nonCachedInputTokens,
+      outputTokens: outputTokens,
+      cacheWriteTokens: cacheWriteTokens,
+      cacheReadTokens: cacheReadTokens,
+      totalCost: totalCost,
+    }
+  }
+
+  override getModel(): OpenAiNativeModel {
     const modelId = this.options.apiModelId
     if (modelId && modelId in openAiNativeModels) {
       const id = modelId as OpenAiNativeModelId
@@ -146,15 +169,15 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
 
   async completePrompt(prompt: string): Promise<string> {
     try {
-      const modelId = this.getModel().id
+      const model = this.getModel()
       let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
 
-      if (modelId.startsWith("o1")) {
-        requestOptions = this.getO1CompletionOptions(modelId, prompt)
-      } else if (modelId.startsWith("o3-mini")) {
-        requestOptions = this.getO3CompletionOptions(modelId, prompt)
+      if (model.id.startsWith("o1")) {
+        requestOptions = this.getO1CompletionOptions(model, prompt)
+      } else if (model.id.startsWith("o3-mini")) {
+        requestOptions = this.getO3CompletionOptions(model, prompt)
       } else {
-        requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
+        requestOptions = this.getDefaultCompletionOptions(model, prompt)
       }
 
       const response = await this.client.chat.completions.create(requestOptions)
@@ -168,17 +191,17 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
   }
 
   private getO1CompletionOptions(
-    modelId: string,
+    model: OpenAiNativeModel,
     prompt: string,
   ): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
     return {
-      model: modelId,
+      model: model.id,
       messages: [{ role: "user", content: prompt }],
     }
   }
 
   private getO3CompletionOptions(
-    modelId: string,
+    model: OpenAiNativeModel,
     prompt: string,
   ): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
     return {
@@ -189,11 +212,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
   }
 
   private getDefaultCompletionOptions(
-    modelId: string,
+    model: OpenAiNativeModel,
     prompt: string,
   ): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
     return {
-      model: modelId,
+      model: model.id,
       messages: [{ role: "user", content: prompt }],
       temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
     }
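
For context, here is a minimal sketch of the cache-aware cost math this commit feeds into. The real pricing logic lives in calculateApiCostOpenAI (imported from src/utils/cost); the pricing field names and per-million-token units below are assumptions for illustration, not that function's actual implementation.

// Hypothetical sketch of cache-aware OpenAI cost accounting (assumed pricing
// fields; the project's real ModelInfo and calculateApiCostOpenAI may differ).
interface ModelPricingSketch {
  inputPrice?: number // USD per 1M non-cached input tokens (assumed)
  outputPrice?: number // USD per 1M output tokens (assumed)
  cacheReadsPrice?: number // USD per 1M cached input tokens (assumed)
}

function sketchOpenAiCost(
  info: ModelPricingSketch,
  promptTokens: number, // usage.prompt_tokens: cache hits + misses combined
  completionTokens: number, // usage.completion_tokens
  cacheReadTokens: number, // usage.prompt_tokens_details?.cached_tokens
): number {
  // prompt_tokens already includes the cached tokens, so bill the cached
  // portion at the cache-read rate and only the remainder at the input rate.
  const nonCachedInput = Math.max(0, promptTokens - cacheReadTokens)
  return (
    (nonCachedInput * (info.inputPrice ?? 0)) / 1_000_000 +
    (cacheReadTokens * (info.cacheReadsPrice ?? 0)) / 1_000_000 +
    (completionTokens * (info.outputPrice ?? 0)) / 1_000_000
  )
}

// Example: 10,000 prompt tokens, 8,000 of them cache hits, 500 output tokens.
// sketchOpenAiCost({ inputPrice: 2.5, outputPrice: 10, cacheReadsPrice: 1.25 }, 10_000, 500, 8_000)

The new yieldUsage generator applies the same idea to the reported counts: only the non-cached portion of prompt_tokens is surfaced as inputTokens, with cache hits surfaced separately as cacheReadTokens, so consumers of the usage chunk do not double-count cached input.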
