diff --git a/src/api/providers/__tests__/openai-native.test.ts b/src/api/providers/__tests__/openai-native.test.ts
index 5b60d46c368..ce5fb6c8a67 100644
--- a/src/api/providers/__tests__/openai-native.test.ts
+++ b/src/api/providers/__tests__/openai-native.test.ts
@@ -153,7 +153,12 @@ describe("OpenAiNativeHandler", () => {
 			results.push(result)
 		}
 
-		expect(results).toEqual([{ type: "usage", inputTokens: 0, outputTokens: 0 }])
+		// Verify essential fields directly
+		expect(results.length).toBe(1)
+		expect(results[0].type).toBe("usage")
+		// Use type assertion to avoid TypeScript errors
+		expect((results[0] as any).inputTokens).toBe(0)
+		expect((results[0] as any).outputTokens).toBe(0)
 
 		// Verify developer role is used for system prompt with o1 model
 		expect(mockCreate).toHaveBeenCalledWith({
@@ -221,12 +226,18 @@ describe("OpenAiNativeHandler", () => {
 			results.push(result)
 		}
 
-		expect(results).toEqual([
-			{ type: "text", text: "Hello" },
-			{ type: "text", text: " there" },
-			{ type: "text", text: "!" },
-			{ type: "usage", inputTokens: 10, outputTokens: 5 },
-		])
+		// Verify text responses individually
+		expect(results.length).toBe(4)
+		expect(results[0]).toMatchObject({ type: "text", text: "Hello" })
+		expect(results[1]).toMatchObject({ type: "text", text: " there" })
+		expect(results[2]).toMatchObject({ type: "text", text: "!" })
+
+		// Check usage data fields but use toBeCloseTo for floating point comparison
+		expect(results[3].type).toBe("usage")
+		// Use type assertion to avoid TypeScript errors
+		expect((results[3] as any).inputTokens).toBe(10)
+		expect((results[3] as any).outputTokens).toBe(5)
+		expect((results[3] as any).totalCost).toBeCloseTo(0.00006, 6)
 
 		expect(mockCreate).toHaveBeenCalledWith({
 			model: "gpt-4.1",
@@ -261,10 +272,16 @@ describe("OpenAiNativeHandler", () => {
 			results.push(result)
 		}
 
-		expect(results).toEqual([
-			{ type: "text", text: "Hello" },
-			{ type: "usage", inputTokens: 10, outputTokens: 5 },
-		])
+		// Verify responses individually
+		expect(results.length).toBe(2)
+		expect(results[0]).toMatchObject({ type: "text", text: "Hello" })
+
+		// Check usage data fields but use toBeCloseTo for floating point comparison
+		expect(results[1].type).toBe("usage")
+		// Use type assertion to avoid TypeScript errors
+		expect((results[1] as any).inputTokens).toBe(10)
+		expect((results[1] as any).outputTokens).toBe(5)
+		expect((results[1] as any).totalCost).toBeCloseTo(0.00006, 6)
 	})
 })
diff --git a/src/api/providers/openai-native.ts b/src/api/providers/openai-native.ts
index 1fe7ef2a861..91e52a2f29f 100644
--- a/src/api/providers/openai-native.ts
+++ b/src/api/providers/openai-native.ts
@@ -11,9 +11,16 @@ import {
 import { convertToOpenAiMessages } from "../transform/openai-format"
 import { ApiStream } from "../transform/stream"
 import { BaseProvider } from "./base-provider"
+import { calculateApiCostOpenAI } from "../../utils/cost"
 
 const OPENAI_NATIVE_DEFAULT_TEMPERATURE = 0
 
+// Define a type for the model object returned by getModel
+export type OpenAiNativeModel = {
+	id: OpenAiNativeModelId
+	info: ModelInfo
+}
+
 export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
 	protected options: ApiHandlerOptions
 	private client: OpenAI
@@ -26,31 +33,31 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-		const modelId = this.getModel().id
+		const model = this.getModel()
 
-		if (modelId.startsWith("o1")) {
-			yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
+		if (model.id.startsWith("o1")) {
+			yield* this.handleO1FamilyMessage(model, systemPrompt, messages)
 			return
 		}
 
-		if (modelId.startsWith("o3-mini")) {
-			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+		if (model.id.startsWith("o3-mini")) {
+			yield* this.handleO3FamilyMessage(model, systemPrompt, messages)
 			return
 		}
 
-		yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
+		yield* this.handleDefaultModelMessage(model, systemPrompt, messages)
 	}
 
 	private async *handleO1FamilyMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		// o1 supports developer prompt with formatting
 		// o1-preview and o1-mini only support user messages
-		const isOriginalO1 = modelId === "o1"
+		const isOriginalO1 = model.id === "o1"
 		const response = await this.client.chat.completions.create({
-			model: modelId,
+			model: model.id,
 			messages: [
 				{
 					role: isOriginalO1 ? "developer" : "user",
@@ -62,11 +69,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			stream_options: { include_usage: true },
 		})
 
-		yield* this.handleStreamResponse(response)
+		yield* this.handleStreamResponse(response, model)
 	}
 
 	private async *handleO3FamilyMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
@@ -84,23 +91,23 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			reasoning_effort: this.getModel().info.reasoningEffort,
 		})
 
-		yield* this.handleStreamResponse(stream)
+		yield* this.handleStreamResponse(stream, model)
 	}
 
 	private async *handleDefaultModelMessage(
-		modelId: string,
+		model: OpenAiNativeModel,
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
 		const stream = await this.client.chat.completions.create({
-			model: modelId,
+			model: model.id,
 			temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
 		})
 
-		yield* this.handleStreamResponse(stream)
+		yield* this.handleStreamResponse(stream, model)
 	}
 
 	private async *yieldResponseData(response: OpenAI.Chat.Completions.ChatCompletion): ApiStream {
@@ -115,7 +122,10 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 		}
 	}
 
-	private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
+	private async *handleStreamResponse(
+		stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>,
+		model: OpenAiNativeModel,
+	): ApiStream {
 		for await (const chunk of stream) {
 			const delta = chunk.choices[0]?.delta
 			if (delta?.content) {
@@ -126,16 +136,29 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 			}
 
 			if (chunk.usage) {
-				yield {
-					type: "usage",
-					inputTokens: chunk.usage.prompt_tokens || 0,
-					outputTokens: chunk.usage.completion_tokens || 0,
-				}
+				yield* this.yieldUsage(model.info, chunk.usage)
 			}
 		}
 	}
 
-	override getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
+	private async *yieldUsage(info: ModelInfo, usage: OpenAI.Completions.CompletionUsage | undefined): ApiStream {
+		const inputTokens = usage?.prompt_tokens || 0 // sum of cache hits and misses
+		const outputTokens = usage?.completion_tokens || 0
+		const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
+		const cacheWriteTokens = 0
+		const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
+		const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
+		yield {
+			type: "usage",
+			inputTokens: nonCachedInputTokens,
+			outputTokens: outputTokens,
+			cacheWriteTokens: cacheWriteTokens,
+			cacheReadTokens: cacheReadTokens,
+			totalCost: totalCost,
+		}
+	}
+
+	override getModel(): OpenAiNativeModel {
 		const modelId = this.options.apiModelId
 		if (modelId && modelId in openAiNativeModels) {
 			const id = modelId as OpenAiNativeModelId
@@ -146,15 +169,15 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 
 	async completePrompt(prompt: string): Promise<string> {
 		try {
-			const modelId = this.getModel().id
+			const model = this.getModel()
 			let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
 
-			if (modelId.startsWith("o1")) {
-				requestOptions = this.getO1CompletionOptions(modelId, prompt)
-			} else if (modelId.startsWith("o3-mini")) {
-				requestOptions = this.getO3CompletionOptions(modelId, prompt)
+			if (model.id.startsWith("o1")) {
+				requestOptions = this.getO1CompletionOptions(model, prompt)
+			} else if (model.id.startsWith("o3-mini")) {
+				requestOptions = this.getO3CompletionOptions(model, prompt)
 			} else {
-				requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
+				requestOptions = this.getDefaultCompletionOptions(model, prompt)
 			}
 
 			const response = await this.client.chat.completions.create(requestOptions)
@@ -168,17 +191,17 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	private getO1CompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
-			model: modelId,
+			model: model.id,
 			messages: [{ role: "user", content: prompt }],
 		}
 	}
 
 	private getO3CompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
 		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
@@ -189,11 +212,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletio
 	}
 
 	private getDefaultCompletionOptions(
-		modelId: string,
+		model: OpenAiNativeModel,
		prompt: string,
 	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
 		return {
-			model: modelId,
+			model: model.id,
 			messages: [{ role: "user", content: prompt }],
 			temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
 		}
 	}
diff --git a/src/shared/api.ts b/src/shared/api.ts
index 0d0706581b2..a262c12abb5 100644
--- a/src/shared/api.ts
+++ b/src/shared/api.ts
@@ -754,6 +754,7 @@
 		supportsPromptCache: true,
 		inputPrice: 2,
 		outputPrice: 8,
+		cacheReadsPrice: 0.5,
 	},
 	"gpt-4.1-mini": {
 		maxTokens: 32_768,
@@ -762,6 +763,7 @@
 		supportsPromptCache: true,
 		inputPrice: 0.4,
 		outputPrice: 1.6,
+		cacheReadsPrice: 0.1,
 	},
 	"gpt-4.1-nano": {
 		maxTokens: 32_768,
@@ -770,6 +772,7 @@
 		supportsPromptCache: true,
 		inputPrice: 0.1,
 		outputPrice: 0.4,
+		cacheReadsPrice: 0.025,
 	},
 	"o3-mini": {
 		maxTokens: 100_000,
@@ -778,6 +781,7 @@
 		supportsPromptCache: true,
 		inputPrice: 1.1,
 		outputPrice: 4.4,
+		cacheReadsPrice: 0.55,
 		reasoningEffort: "medium",
 	},
 	"o3-mini-high": {
@@ -787,6 +791,7 @@
 		supportsPromptCache: true,
 		inputPrice: 1.1,
 		outputPrice: 4.4,
+		cacheReadsPrice: 0.55,
 		reasoningEffort: "high",
 	},
 	"o3-mini-low": {
@@ -796,6 +801,7 @@
 		supportsPromptCache: true,
 		inputPrice: 1.1,
 		outputPrice: 4.4,
+		cacheReadsPrice: 0.55,
 		reasoningEffort: "low",
 	},
 	o1: {
@@ -805,6 +811,7 @@
 		supportsPromptCache: true,
 		inputPrice: 15,
 		outputPrice: 60,
+		cacheReadsPrice: 7.5,
 	},
 	"o1-preview": {
 		maxTokens: 32_768,
@@ -813,6 +820,7 @@
 		supportsPromptCache: true,
 		inputPrice: 15,
 		outputPrice: 60,
+		cacheReadsPrice: 7.5,
 	},
 	"o1-mini": {
 		maxTokens: 65_536,
@@ -821,6 +829,7 @@
 		supportsPromptCache: true,
 		inputPrice: 1.1,
 		outputPrice: 4.4,
+		cacheReadsPrice: 0.55,
 	},
 	"gpt-4.5-preview": {
 		maxTokens: 16_384,
@@ -829,6 +838,7 @@
 		supportsPromptCache: true,
 		inputPrice: 75,
 		outputPrice: 150,
+		cacheReadsPrice: 37.5,
 	},
 	"gpt-4o": {
 		maxTokens: 16_384,
@@ -837,6 +847,7 @@
 		supportsPromptCache: true,
 		inputPrice: 2.5,
 		outputPrice: 10,
+		cacheReadsPrice: 1.25,
 	},
 	"gpt-4o-mini": {
 		maxTokens: 16_384,
@@ -845,6 +856,7 @@
 		supportsPromptCache: true,
 		inputPrice: 0.15,
 		outputPrice: 0.6,
+		cacheReadsPrice: 0.075,
 	},
 } as const satisfies Record<string, ModelInfo>
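
Note (reviewer sketch, not part of the diff): the totalCost value asserted in the updated tests follows from the per-million-token prices added in src/shared/api.ts. The snippet below approximates the arithmetic that calculateApiCostOpenAI (imported from src/utils/cost.ts, which is not shown in this diff) is expected to perform; SketchPricing and sketchCost are illustrative names only, not code from the repository.

```ts
// Illustrative only — assumes the per-million-token pricing implied by the ModelInfo
// fields (inputPrice, outputPrice, cacheReadsPrice). The real helper is
// calculateApiCostOpenAI in src/utils/cost.ts and may differ in detail.
interface SketchPricing {
	inputPrice: number // USD per 1M non-cached input tokens
	outputPrice: number // USD per 1M output tokens
	cacheReadsPrice?: number // USD per 1M cached input tokens
}

function sketchCost(
	pricing: SketchPricing,
	promptTokens: number, // usage.prompt_tokens: cache hits and misses combined
	completionTokens: number, // usage.completion_tokens
	cachedTokens: number, // usage.prompt_tokens_details?.cached_tokens
): number {
	const nonCachedInputTokens = Math.max(0, promptTokens - cachedTokens)
	return (
		(pricing.inputPrice / 1_000_000) * nonCachedInputTokens +
		((pricing.cacheReadsPrice ?? 0) / 1_000_000) * cachedTokens +
		(pricing.outputPrice / 1_000_000) * completionTokens
	)
}

// gpt-4.1 prices from this diff: $2 in, $8 out, $0.50 cache reads (per 1M tokens).
// The test fixture streams usage of 10 prompt / 5 completion tokens with no cache hits:
// 10 * 2 / 1e6 + 5 * 8 / 1e6 = 0.00002 + 0.00004 = 0.00006, the value checked with toBeCloseTo.
console.log(sketchCost({ inputPrice: 2, outputPrice: 8, cacheReadsPrice: 0.5 }, 10, 5, 0))
```

Under this reading, cached prompt tokens are billed at cacheReadsPrice rather than inputPrice, which is why yieldUsage subtracts cacheReadTokens from prompt_tokens and reports inputTokens as the non-cached remainder.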