
Commit df93293

Refactor openai-native, prepend string to developer instructions so that o1/o3 will use Markdown
1 parent d78da19 commit df93293
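
The gist of the change: OpenAI's reasoning models (o1, o3-mini) tend to avoid Markdown in their output when instructions arrive via the developer role, so the handler now prepends the line "Formatting re-enabled" to the system prompt before sending it as a developer message. A minimal sketch of the idea in TypeScript; the helper name is illustrative only and does not appear in the diff, which inlines this logic instead:

// Illustrative helper (not part of the commit): build the developer message
// with the "Formatting re-enabled" prefix that the refactored handler adds.
function toDeveloperMessage(systemPrompt: string): { role: "developer"; content: string } {
	return { role: "developer", content: `Formatting re-enabled\n${systemPrompt}` }
}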

File tree

2 files changed (+163, -63 lines)


src/api/providers/__tests__/openai-native.test.ts

Lines changed: 25 additions & 1 deletion
@@ -153,11 +153,35 @@ describe("OpenAiNativeHandler", () => {
 			expect(mockCreate).toHaveBeenCalledWith({
 				model: "o1",
 				messages: [
-					{ role: "developer", content: systemPrompt },
+					{ role: "developer", content: "Formatting re-enabled\n" + systemPrompt },
 					{ role: "user", content: "Hello!" },
 				],
 			})
 		})
+
+		it("should handle o3-mini model family correctly", async () => {
+			handler = new OpenAiNativeHandler({
+				...mockOptions,
+				apiModelId: "o3-mini",
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith({
+				model: "o3-mini",
+				messages: [
+					{ role: "developer", content: "Formatting re-enabled\n" + systemPrompt },
+					{ role: "user", content: "Hello!" },
+				],
+				stream: true,
+				stream_options: { include_usage: true },
+				reasoning_effort: "medium",
+			})
+		})
 	})
 
 	describe("streaming models", () => {

src/api/providers/openai-native.ts

Lines changed: 138 additions & 62 deletions
@@ -24,57 +24,111 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const modelId = this.getModel().id
-		switch (modelId) {
-			case "o1":
-			case "o1-preview":
-			case "o1-mini": {
-				// o1-preview and o1-mini don't support streaming, non-1 temp, or system prompt
-				// o1 doesnt support streaming or non-1 temp but does support a developer prompt
-				const response = await this.client.chat.completions.create({
-					model: modelId,
-					messages: [
-						{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
-						...convertToOpenAiMessages(messages),
-					],
-				})
+
+		if (modelId.startsWith("o1")) {
+			yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
+			return
+		}
+
+		if (modelId.startsWith("o3-mini")) {
+			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+			return
+		}
+
+		yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
+	}
+
+	private async *handleO1FamilyMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		// o1 supports developer prompt with formatting
+		// o1-preview and o1-mini only support user messages
+		const isOriginalO1 = modelId === "o1"
+		const response = await this.client.chat.completions.create({
+			model: modelId,
+			messages: [
+				{
+					role: isOriginalO1 ? "developer" : "user",
+					content: isOriginalO1 ? `Formatting re-enabled\n${systemPrompt}` : systemPrompt,
+				},
+				...convertToOpenAiMessages(messages),
+			],
+		})
+
+		yield* this.yieldResponseData(response)
+	}
+
+	private async *handleO3FamilyMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		const stream = await this.client.chat.completions.create({
+			model: "o3-mini",
+			messages: [
+				{
+					role: "developer",
+					content: `Formatting re-enabled\n${systemPrompt}`,
+				},
+				...convertToOpenAiMessages(messages),
+			],
+			stream: true,
+			stream_options: { include_usage: true },
+			reasoning_effort: this.getModel().info.reasoningEffort,
+		})
+
+		yield* this.handleStreamResponse(stream)
+	}
+
+	private async *handleDefaultModelMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		const stream = await this.client.chat.completions.create({
+			model: modelId,
+			temperature: 0,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		})
+
+		yield* this.handleStreamResponse(stream)
+	}
+
+	private async *yieldResponseData(
+		response: OpenAI.Chat.Completions.ChatCompletion
+	): ApiStream {
+		yield {
+			type: "text",
+			text: response.choices[0]?.message.content || "",
+		}
+		yield {
+			type: "usage",
+			inputTokens: response.usage?.prompt_tokens || 0,
+			outputTokens: response.usage?.completion_tokens || 0,
+		}
+	}
+
+	private async *handleStreamResponse(
+		stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
+	): ApiStream {
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+			if (delta?.content) {
 				yield {
 					type: "text",
-					text: response.choices[0]?.message.content || "",
+					text: delta.content,
 				}
+			}
+
+			if (chunk.usage) {
 				yield {
 					type: "usage",
-					inputTokens: response.usage?.prompt_tokens || 0,
-					outputTokens: response.usage?.completion_tokens || 0,
-				}
-				break
-			}
-			default: {
-				const stream = await this.client.chat.completions.create({
-					model: this.getModel().id,
-					// max_completion_tokens: this.getModel().info.maxTokens,
-					temperature: 0,
-					messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-					stream: true,
-					stream_options: { include_usage: true },
-				})
-
-				for await (const chunk of stream) {
-					const delta = chunk.choices[0]?.delta
-					if (delta?.content) {
-						yield {
-							type: "text",
-							text: delta.content,
-						}
-					}
-
-					// contains a null value except for the last chunk which contains the token usage statistics for the entire request
-					if (chunk.usage) {
-						yield {
-							type: "usage",
-							inputTokens: chunk.usage.prompt_tokens || 0,
-							outputTokens: chunk.usage.completion_tokens || 0,
-						}
-					}
+					inputTokens: chunk.usage.prompt_tokens || 0,
+					outputTokens: chunk.usage.completion_tokens || 0,
 				}
 			}
 		}
 	}
@@ -94,22 +148,12 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 		const modelId = this.getModel().id
 		let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
 
-		switch (modelId) {
-			case "o1":
-			case "o1-preview":
-			case "o1-mini":
-				// o1 doesn't support non-1 temp
-				requestOptions = {
-					model: modelId,
-					messages: [{ role: "user", content: prompt }],
-				}
-				break
-			default:
-				requestOptions = {
-					model: modelId,
-					messages: [{ role: "user", content: prompt }],
-					temperature: 0,
-				}
+		if (modelId.startsWith("o1")) {
+			requestOptions = this.getO1CompletionOptions(modelId, prompt)
+		} else if (modelId.startsWith("o3-mini")) {
+			requestOptions = this.getO3CompletionOptions(modelId, prompt)
+		} else {
+			requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
 		}
 
 		const response = await this.client.chat.completions.create(requestOptions)
@@ -121,4 +165,36 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 			throw error
 		}
 	}
+
+	private getO1CompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+	}
+
+	private getO3CompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: "o3-mini",
+			messages: [{ role: "user", content: prompt }],
+			reasoning_effort: this.getModel().info.reasoningEffort,
+		}
+	}
+
+	private getDefaultCompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+			temperature: 0,
+		}
+	}
 }
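
For orientation, a caller consumes the ApiStream that createMessage yields the same way regardless of which private helper handled the request. A minimal sketch based on the test above; handler options beyond apiModelId are assumed to mirror the test's mockOptions, and systemPrompt/messages are the test fixtures:

// Sketch only: drives the refactored handler for an o3-mini model and
// consumes the "text" and "usage" chunks it yields.
const handler = new OpenAiNativeHandler({ ...mockOptions, apiModelId: "o3-mini" })

let output = ""
for await (const chunk of handler.createMessage(systemPrompt, messages)) {
	if (chunk.type === "text") {
		output += chunk.text // streamed deltas, or the full response for non-streaming o1 models
	} else if (chunk.type === "usage") {
		console.log(chunk.inputTokens, chunk.outputTokens) // token counts from the final chunk
	}
}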
