39 changes: 28 additions & 11 deletions src/api/providers/__tests__/openai-native.test.ts
@@ -153,7 +153,12 @@ describe("OpenAiNativeHandler", () => {
results.push(result)
}

expect(results).toEqual([{ type: "usage", inputTokens: 0, outputTokens: 0 }])
// Verify essential fields directly
expect(results.length).toBe(1)
expect(results[0].type).toBe("usage")
// Use type assertion to avoid TypeScript errors
expect((results[0] as any).inputTokens).toBe(0)
expect((results[0] as any).outputTokens).toBe(0)

// Verify developer role is used for system prompt with o1 model
expect(mockCreate).toHaveBeenCalledWith({
@@ -221,12 +226,18 @@ describe("OpenAiNativeHandler", () => {
results.push(result)
}

expect(results).toEqual([
{ type: "text", text: "Hello" },
{ type: "text", text: " there" },
{ type: "text", text: "!" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
// Verify text responses individually
expect(results.length).toBe(4)
expect(results[0]).toMatchObject({ type: "text", text: "Hello" })
expect(results[1]).toMatchObject({ type: "text", text: " there" })
expect(results[2]).toMatchObject({ type: "text", text: "!" })

// Check usage data fields but use toBeCloseTo for floating point comparison
expect(results[3].type).toBe("usage")
// Use type assertion to avoid TypeScript errors
expect((results[3] as any).inputTokens).toBe(10)
Reviewer comment (Contributor): Consider replacing the any cast with a more specific type or unknown to improve type safety in these assertions.
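As a rough illustration of that suggestion (the UsageChunk shape below is assumed for the sketch; the real chunk union defined in ../transform/stream may already export an equivalent type), the assertion can narrow through a declared type instead of any, using the test's existing results array:

// Hypothetical shape for the usage chunk yielded by the stream.
type UsageChunk = {
    type: "usage"
    inputTokens: number
    outputTokens: number
    totalCost?: number
}

// Asserting through unknown keeps the property names type-checked, unlike `as any`.
const usage = results[3] as unknown as UsageChunk
expect(usage.inputTokens).toBe(10)
expect(usage.outputTokens).toBe(5)
expect(usage.totalCost).toBeCloseTo(0.00006, 6)

For reference, the expected 0.00006 is consistent with gpt-4.1's listed prices (2 input, 8 output, per million tokens): 10 * 2 / 1e6 + 5 * 8 / 1e6 = 0.00006.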

expect((results[3] as any).outputTokens).toBe(5)
expect((results[3] as any).totalCost).toBeCloseTo(0.00006, 6)

expect(mockCreate).toHaveBeenCalledWith({
model: "gpt-4.1",
@@ -261,10 +272,16 @@ describe("OpenAiNativeHandler", () => {
results.push(result)
}

expect(results).toEqual([
{ type: "text", text: "Hello" },
{ type: "usage", inputTokens: 10, outputTokens: 5 },
])
// Verify responses individually
expect(results.length).toBe(2)
expect(results[0]).toMatchObject({ type: "text", text: "Hello" })

// Check usage data fields but use toBeCloseTo for floating point comparison
expect(results[1].type).toBe("usage")
// Use type assertion to avoid TypeScript errors
expect((results[1] as any).inputTokens).toBe(10)
Reviewer comment (Contributor): Consider defining proper result types to eliminate repeated 'as any' assertions, enhancing maintainability.

Reviewer comment (Contributor): Repeated type assertions for usage properties appear in multiple tests. Consider extracting these assertions into a helper function to reduce duplication and improve maintainability.
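A minimal sketch of such a helper, assuming Jest globals from this test file and the usage fields exercised above (inputTokens, outputTokens, totalCost); an exported chunk type from ../transform/stream could replace the inline assertion:

// Shared test helper: asserts the usage fields of a stream chunk in one place,
// so individual tests avoid repeating `as any` casts.
type ExpectedUsage = { inputTokens: number; outputTokens: number; totalCost?: number }

function expectUsageChunk(chunk: unknown, expected: ExpectedUsage): void {
    const usage = chunk as { type: string; inputTokens: number; outputTokens: number; totalCost?: number }
    expect(usage.type).toBe("usage")
    expect(usage.inputTokens).toBe(expected.inputTokens)
    expect(usage.outputTokens).toBe(expected.outputTokens)
    if (expected.totalCost !== undefined) {
        expect(usage.totalCost).toBeCloseTo(expected.totalCost, 6)
    }
}

// Hypothetical call sites for the assertions above:
// expectUsageChunk(results[3], { inputTokens: 10, outputTokens: 5, totalCost: 0.00006 })
// expectUsageChunk(results[1], { inputTokens: 10, outputTokens: 5, totalCost: 0.00006 })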
expect((results[1] as any).outputTokens).toBe(5)
expect((results[1] as any).totalCost).toBeCloseTo(0.00006, 6)
})
})

89 changes: 56 additions & 33 deletions src/api/providers/openai-native.ts
@@ -11,9 +11,16 @@ import {
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { calculateApiCostOpenAI } from "../../utils/cost"

const OPENAI_NATIVE_DEFAULT_TEMPERATURE = 0

// Define a type for the model object returned by getModel
export type OpenAiNativeModel = {
id: OpenAiNativeModelId
info: ModelInfo
}

export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI
@@ -26,31 +33,31 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}

override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const modelId = this.getModel().id
const model = this.getModel()

if (modelId.startsWith("o1")) {
yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
if (model.id.startsWith("o1")) {
yield* this.handleO1FamilyMessage(model, systemPrompt, messages)
return
}

if (modelId.startsWith("o3-mini")) {
yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
if (model.id.startsWith("o3-mini")) {
yield* this.handleO3FamilyMessage(model, systemPrompt, messages)
return
}

yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
yield* this.handleDefaultModelMessage(model, systemPrompt, messages)
}

private async *handleO1FamilyMessage(
modelId: string,
model: OpenAiNativeModel,
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): ApiStream {
// o1 supports developer prompt with formatting
// o1-preview and o1-mini only support user messages
const isOriginalO1 = modelId === "o1"
const isOriginalO1 = model.id === "o1"
const response = await this.client.chat.completions.create({
model: modelId,
model: model.id,
messages: [
{
role: isOriginalO1 ? "developer" : "user",
@@ -62,11 +69,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
stream_options: { include_usage: true },
})

yield* this.handleStreamResponse(response)
yield* this.handleStreamResponse(response, model)
}

private async *handleO3FamilyMessage(
modelId: string,
model: OpenAiNativeModel,
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): ApiStream {
@@ -84,23 +91,23 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
reasoning_effort: this.getModel().info.reasoningEffort,
})

yield* this.handleStreamResponse(stream)
yield* this.handleStreamResponse(stream, model)
}

private async *handleDefaultModelMessage(
modelId: string,
model: OpenAiNativeModel,
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): ApiStream {
const stream = await this.client.chat.completions.create({
model: modelId,
model: model.id,
temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
})

yield* this.handleStreamResponse(stream)
yield* this.handleStreamResponse(stream, model)
}

private async *yieldResponseData(response: OpenAI.Chat.Completions.ChatCompletion): ApiStream {
@@ -115,7 +122,10 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}
}

private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
private async *handleStreamResponse(
stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>,
model: OpenAiNativeModel,
): ApiStream {
for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta
if (delta?.content) {
@@ -126,16 +136,29 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}

if (chunk.usage) {
yield {
type: "usage",
inputTokens: chunk.usage.prompt_tokens || 0,
outputTokens: chunk.usage.completion_tokens || 0,
}
yield* this.yieldUsage(model.info, chunk.usage)
}
}
}

override getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
private async *yieldUsage(info: ModelInfo, usage: OpenAI.Completions.CompletionUsage | undefined): ApiStream {
const inputTokens = usage?.prompt_tokens || 0 // sum of cache hits and misses
const outputTokens = usage?.completion_tokens || 0
const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
const cacheWriteTokens = 0
const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
yield {
type: "usage",
inputTokens: nonCachedInputTokens,
outputTokens: outputTokens,
cacheWriteTokens: cacheWriteTokens,
cacheReadTokens: cacheReadTokens,
totalCost: totalCost,
}
}

override getModel(): OpenAiNativeModel {
const modelId = this.options.apiModelId
if (modelId && modelId in openAiNativeModels) {
const id = modelId as OpenAiNativeModelId
@@ -146,15 +169,15 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {

async completePrompt(prompt: string): Promise<string> {
try {
const modelId = this.getModel().id
const model = this.getModel()
let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming

if (modelId.startsWith("o1")) {
requestOptions = this.getO1CompletionOptions(modelId, prompt)
} else if (modelId.startsWith("o3-mini")) {
requestOptions = this.getO3CompletionOptions(modelId, prompt)
if (model.id.startsWith("o1")) {
requestOptions = this.getO1CompletionOptions(model, prompt)
} else if (model.id.startsWith("o3-mini")) {
requestOptions = this.getO3CompletionOptions(model, prompt)
} else {
requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
requestOptions = this.getDefaultCompletionOptions(model, prompt)
}

const response = await this.client.chat.completions.create(requestOptions)
@@ -168,17 +191,17 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}

private getO1CompletionOptions(
modelId: string,
model: OpenAiNativeModel,
prompt: string,
): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
return {
model: modelId,
model: model.id,
messages: [{ role: "user", content: prompt }],
}
}

private getO3CompletionOptions(
modelId: string,
model: OpenAiNativeModel,
prompt: string,
): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
return {
@@ -189,11 +212,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}

private getDefaultCompletionOptions(
modelId: string,
model: OpenAiNativeModel,
prompt: string,
): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
return {
model: modelId,
model: model.id,
messages: [{ role: "user", content: prompt }],
temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
}
12 changes: 12 additions & 0 deletions src/shared/api.ts
@@ -754,6 +754,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 2,
outputPrice: 8,
cacheReadsPrice: 0.5,
},
"gpt-4.1-mini": {
maxTokens: 32_768,
Expand All @@ -762,6 +763,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 0.4,
outputPrice: 1.6,
cacheReadsPrice: 0.1,
},
"gpt-4.1-nano": {
maxTokens: 32_768,
Expand All @@ -770,6 +772,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 0.1,
outputPrice: 0.4,
cacheReadsPrice: 0.025,
},
"o3-mini": {
maxTokens: 100_000,
Expand All @@ -778,6 +781,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 1.1,
outputPrice: 4.4,
cacheReadsPrice: 0.55,
reasoningEffort: "medium",
},
"o3-mini-high": {
Expand All @@ -787,6 +791,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 1.1,
outputPrice: 4.4,
cacheReadsPrice: 0.55,
reasoningEffort: "high",
},
"o3-mini-low": {
Expand All @@ -796,6 +801,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 1.1,
outputPrice: 4.4,
cacheReadsPrice: 0.55,
reasoningEffort: "low",
},
o1: {
Expand All @@ -805,6 +811,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 15,
outputPrice: 60,
cacheReadsPrice: 7.5,
},
"o1-preview": {
maxTokens: 32_768,
Expand All @@ -813,6 +820,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 15,
outputPrice: 60,
cacheReadsPrice: 7.5,
},
"o1-mini": {
maxTokens: 65_536,
Expand All @@ -821,6 +829,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 1.1,
outputPrice: 4.4,
cacheReadsPrice: 0.55,
},
"gpt-4.5-preview": {
maxTokens: 16_384,
Expand All @@ -829,6 +838,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 75,
outputPrice: 150,
cacheReadsPrice: 37.5,
},
"gpt-4o": {
maxTokens: 16_384,
Expand All @@ -837,6 +847,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 2.5,
outputPrice: 10,
cacheReadsPrice: 1.25,
},
"gpt-4o-mini": {
maxTokens: 16_384,
Expand All @@ -845,6 +856,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 0.15,
outputPrice: 0.6,
cacheReadsPrice: 0.075,
},
} as const satisfies Record<string, ModelInfo>
