89 changes: 56 additions & 33 deletions src/api/providers/openai-native.ts
@@ -11,9 +11,16 @@ import {
import { convertToOpenAiMessages } from "../transform/openai-format"
import { ApiStream } from "../transform/stream"
import { BaseProvider } from "./base-provider"
import { calculateApiCostOpenAI } from "../../utils/cost"

const OPENAI_NATIVE_DEFAULT_TEMPERATURE = 0

// Define a type for the model object returned by getModel
export type OpenAiNativeModel = {
id: OpenAiNativeModelId
info: ModelInfo
}

export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
protected options: ApiHandlerOptions
private client: OpenAI
@@ -26,31 +33,31 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}

override async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
const modelId = this.getModel().id
const model = this.getModel()

if (modelId.startsWith("o1")) {
yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
if (model.id.startsWith("o1")) {
yield* this.handleO1FamilyMessage(model, systemPrompt, messages)
return
}

if (modelId.startsWith("o3-mini")) {
yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
if (model.id.startsWith("o3-mini")) {
yield* this.handleO3FamilyMessage(model, systemPrompt, messages)
return
}

yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
yield* this.handleDefaultModelMessage(model, systemPrompt, messages)
}

private async *handleO1FamilyMessage(
modelId: string,
model: OpenAiNativeModel,
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): ApiStream {
// o1 supports developer prompt with formatting
// o1-preview and o1-mini only support user messages
const isOriginalO1 = modelId === "o1"
const isOriginalO1 = model.id === "o1"
const response = await this.client.chat.completions.create({
model: modelId,
model: model.id,
messages: [
{
role: isOriginalO1 ? "developer" : "user",
@@ -62,11 +69,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
stream_options: { include_usage: true },
})

yield* this.handleStreamResponse(response)
yield* this.handleStreamResponse(response, model)
}

private async *handleO3FamilyMessage(
modelId: string,
model: OpenAiNativeModel,
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): ApiStream {
@@ -84,23 +91,23 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
reasoning_effort: this.getModel().info.reasoningEffort,
})

yield* this.handleStreamResponse(stream)
yield* this.handleStreamResponse(stream, model)
}

private async *handleDefaultModelMessage(
modelId: string,
model: OpenAiNativeModel,
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): ApiStream {
const stream = await this.client.chat.completions.create({
model: modelId,
model: model.id,
temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
stream: true,
stream_options: { include_usage: true },
})

yield* this.handleStreamResponse(stream)
yield* this.handleStreamResponse(stream, model)
}

private async *yieldResponseData(response: OpenAI.Chat.Completions.ChatCompletion): ApiStream {
@@ -115,7 +122,10 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}
}

private async *handleStreamResponse(stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>): ApiStream {
private async *handleStreamResponse(
stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>,
model: OpenAiNativeModel,
): ApiStream {
for await (const chunk of stream) {
const delta = chunk.choices[0]?.delta
if (delta?.content) {
@@ -126,16 +136,29 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}

if (chunk.usage) {
yield {
type: "usage",
inputTokens: chunk.usage.prompt_tokens || 0,
outputTokens: chunk.usage.completion_tokens || 0,
}
yield* this.yieldUsage(model.info, chunk.usage)
}
}
}

override getModel(): { id: OpenAiNativeModelId; info: ModelInfo } {
private async *yieldUsage(info: ModelInfo, usage: OpenAI.Completions.CompletionUsage | undefined): ApiStream {
const inputTokens = usage?.prompt_tokens || 0 // sum of cache hits and misses
const outputTokens = usage?.completion_tokens || 0
const cacheReadTokens = usage?.prompt_tokens_details?.cached_tokens || 0
const cacheWriteTokens = 0
const totalCost = calculateApiCostOpenAI(info, inputTokens, outputTokens, cacheWriteTokens, cacheReadTokens)
const nonCachedInputTokens = Math.max(0, inputTokens - cacheReadTokens - cacheWriteTokens)
yield {
type: "usage",
inputTokens: nonCachedInputTokens,
outputTokens: outputTokens,
cacheWriteTokens: cacheWriteTokens,
cacheReadTokens: cacheReadTokens,
totalCost: totalCost,
}
}

override getModel(): OpenAiNativeModel {
const modelId = this.options.apiModelId
if (modelId && modelId in openAiNativeModels) {
const id = modelId as OpenAiNativeModelId
@@ -146,15 +169,15 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {

async completePrompt(prompt: string): Promise<string> {
try {
const modelId = this.getModel().id
const model = this.getModel()
let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming

if (modelId.startsWith("o1")) {
requestOptions = this.getO1CompletionOptions(modelId, prompt)
} else if (modelId.startsWith("o3-mini")) {
requestOptions = this.getO3CompletionOptions(modelId, prompt)
if (model.id.startsWith("o1")) {
requestOptions = this.getO1CompletionOptions(model, prompt)
} else if (model.id.startsWith("o3-mini")) {
requestOptions = this.getO3CompletionOptions(model, prompt)
} else {
requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
requestOptions = this.getDefaultCompletionOptions(model, prompt)
}

const response = await this.client.chat.completions.create(requestOptions)
@@ -168,17 +191,17 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}

private getO1CompletionOptions(
modelId: string,
model: OpenAiNativeModel,
prompt: string,
): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
return {
model: modelId,
model: model.id,
messages: [{ role: "user", content: prompt }],
}
}

private getO3CompletionOptions(
modelId: string,
model: OpenAiNativeModel,
prompt: string,
): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
return {
@@ -189,11 +212,11 @@ export class OpenAiNativeHandler extends BaseProvider implements SingleCompletionHandler {
}

private getDefaultCompletionOptions(
modelId: string,
model: OpenAiNativeModel,
prompt: string,
): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
return {
model: modelId,
model: model.id,
messages: [{ role: "user", content: prompt }],
temperature: this.options.modelTemperature ?? OPENAI_NATIVE_DEFAULT_TEMPERATURE,
}
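For reference, here is a minimal sketch of what `yieldUsage` assumes `calculateApiCostOpenAI` does with the values above. It is an illustration of the pricing convention, not the actual code in `src/utils/cost.ts`, and the `cacheWritesPrice` field is likewise an assumption. Because OpenAI's `prompt_tokens` already includes cached tokens, cache reads are priced separately and subtracted from the input bucket, which is also why `yieldUsage` reports `nonCachedInputTokens` rather than the raw prompt token count.

```typescript
// Sketch only — the real implementation lives in src/utils/cost.ts and may differ.
// Assumptions: prices in ModelInfo are USD per 1M tokens, and prompt_tokens
// already includes cached tokens, so cache reads/writes are priced separately
// and removed from the regular input bucket.
function calculateApiCostOpenAISketch(
	info: { inputPrice?: number; outputPrice?: number; cacheWritesPrice?: number; cacheReadsPrice?: number },
	inputTokens: number,
	outputTokens: number,
	cacheWriteTokens = 0,
	cacheReadTokens = 0,
): number {
	// Tokens billed at the normal input rate.
	const nonCachedInputTokens = Math.max(0, inputTokens - cacheWriteTokens - cacheReadTokens)
	return (
		((info.inputPrice ?? 0) / 1_000_000) * nonCachedInputTokens +
		((info.cacheWritesPrice ?? 0) / 1_000_000) * cacheWriteTokens +
		((info.cacheReadsPrice ?? 0) / 1_000_000) * cacheReadTokens +
		((info.outputPrice ?? 0) / 1_000_000) * outputTokens
	)
}
```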
12 changes: 12 additions & 0 deletions src/shared/api.ts
@@ -754,6 +754,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 2,
outputPrice: 8,
cacheReadsPrice: 0.5,
},
"gpt-4.1-mini": {
maxTokens: 32_768,
@@ -762,6 +763,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 0.4,
outputPrice: 1.6,
cacheReadsPrice: 0.1,
},
"gpt-4.1-nano": {
maxTokens: 32_768,
@@ -770,6 +772,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 0.1,
outputPrice: 0.4,
cacheReadsPrice: 0.025,
},
"o3-mini": {
maxTokens: 100_000,
@@ -778,6 +781,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 1.1,
outputPrice: 4.4,
cacheReadsPrice: 0.55,
reasoningEffort: "medium",
},
"o3-mini-high": {
@@ -787,6 +791,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 1.1,
outputPrice: 4.4,
cacheReadsPrice: 0.55,
reasoningEffort: "high",
},
"o3-mini-low": {
@@ -796,6 +801,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 1.1,
outputPrice: 4.4,
cacheReadsPrice: 0.55,
reasoningEffort: "low",
},
o1: {
@@ -805,6 +811,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 15,
outputPrice: 60,
cacheReadsPrice: 7.5,
},
"o1-preview": {
maxTokens: 32_768,
@@ -813,6 +820,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 15,
outputPrice: 60,
cacheReadsPrice: 7.5,
},
"o1-mini": {
maxTokens: 65_536,
@@ -821,6 +829,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 1.1,
outputPrice: 4.4,
cacheReadsPrice: 0.55,
},
"gpt-4.5-preview": {
maxTokens: 16_384,
@@ -829,6 +838,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 75,
outputPrice: 150,
cacheReadsPrice: 37.5,
},
"gpt-4o": {
maxTokens: 16_384,
@@ -837,6 +847,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 2.5,
outputPrice: 10,
cacheReadsPrice: 1.25,
},
"gpt-4o-mini": {
maxTokens: 16_384,
@@ -845,6 +856,7 @@ export const openAiNativeModels = {
supportsPromptCache: true,
inputPrice: 0.15,
outputPrice: 0.6,
cacheReadsPrice: 0.075,
},
} as const satisfies Record<string, ModelInfo>
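As a quick sanity check on the new `cacheReadsPrice` values (a hypothetical request, not taken from this PR): with `gpt-4o`, a call using 10,000 prompt tokens of which 8,000 are cache hits, plus 500 completion tokens, would cost roughly 2,000 × $2.5/1M + 8,000 × $1.25/1M + 500 × $10/1M = $0.005 + $0.01 + $0.005 = $0.02, versus $0.03 with no cache hits — consistent with cache reads being billed at half the normal input rate for that model.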
