
Commit cdbc7c5

Merge pull request #725 from nissa-seru/fix-o3-formatting
Refactor openai-native, prepend string to developer instructions so t…
2 parents 730bf8d + 27976b7 commit cdbc7c5
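
The truncated title refers to prepending "Formatting re-enabled" to the developer message sent to the o1 and o3-mini model families, which appears to follow OpenAI's guidance for its reasoning models (they tend to suppress markdown output when given a developer message). As a rough illustration only, not part of the diff, the developer message the handler builds after this change looks roughly like this; the systemPrompt value is a stand-in:

	// Illustrative sketch only: the developer message shape after this change.
	const systemPrompt = "You are a helpful assistant."

	const developerMessage = {
		role: "developer" as const,
		content: `Formatting re-enabled\n${systemPrompt}`,
	}

	console.log(developerMessage.content)
	// Formatting re-enabled
	// You are a helpful assistant.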

File tree

2 files changed (+163, -104 lines)


src/api/providers/__tests__/openai-native.test.ts

Lines changed: 25 additions & 1 deletion
@@ -153,11 +153,35 @@ describe("OpenAiNativeHandler", () => {
 			expect(mockCreate).toHaveBeenCalledWith({
 				model: "o1",
 				messages: [
-					{ role: "developer", content: systemPrompt },
+					{ role: "developer", content: "Formatting re-enabled\n" + systemPrompt },
 					{ role: "user", content: "Hello!" },
 				],
 			})
 		})
+
+		it("should handle o3-mini model family correctly", async () => {
+			handler = new OpenAiNativeHandler({
+				...mockOptions,
+				apiModelId: "o3-mini",
+			})
+
+			const stream = handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith({
+				model: "o3-mini",
+				messages: [
+					{ role: "developer", content: "Formatting re-enabled\n" + systemPrompt },
+					{ role: "user", content: "Hello!" },
+				],
+				stream: true,
+				stream_options: { include_usage: true },
+				reasoning_effort: "medium",
+			})
+		})
 	})
 
 	describe("streaming models", () => {

src/api/providers/openai-native.ts

Lines changed: 138 additions & 103 deletions
@@ -24,88 +24,111 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 
 	async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
 		const modelId = this.getModel().id
-		switch (modelId) {
-			case "o1":
-			case "o1-preview":
-			case "o1-mini": {
-				// o1-preview and o1-mini don't support streaming, non-1 temp, or system prompt
-				// o1 doesnt support streaming or non-1 temp but does support a developer prompt
-				const response = await this.client.chat.completions.create({
-					model: modelId,
-					messages: [
-						{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
-						...convertToOpenAiMessages(messages),
-					],
-				})
+
+		if (modelId.startsWith("o1")) {
+			yield* this.handleO1FamilyMessage(modelId, systemPrompt, messages)
+			return
+		}
+
+		if (modelId.startsWith("o3-mini")) {
+			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages)
+			return
+		}
+
+		yield* this.handleDefaultModelMessage(modelId, systemPrompt, messages)
+	}
+
+	private async *handleO1FamilyMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		// o1 supports developer prompt with formatting
+		// o1-preview and o1-mini only support user messages
+		const isOriginalO1 = modelId === "o1"
+		const response = await this.client.chat.completions.create({
+			model: modelId,
+			messages: [
+				{
+					role: isOriginalO1 ? "developer" : "user",
+					content: isOriginalO1 ? `Formatting re-enabled\n${systemPrompt}` : systemPrompt,
+				},
+				...convertToOpenAiMessages(messages),
+			],
+		})
+
+		yield* this.yieldResponseData(response)
+	}
+
+	private async *handleO3FamilyMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		const stream = await this.client.chat.completions.create({
+			model: "o3-mini",
+			messages: [
+				{
+					role: "developer",
+					content: `Formatting re-enabled\n${systemPrompt}`,
+				},
+				...convertToOpenAiMessages(messages),
+			],
+			stream: true,
+			stream_options: { include_usage: true },
+			reasoning_effort: this.getModel().info.reasoningEffort,
+		})
+
+		yield* this.handleStreamResponse(stream)
+	}
+
+	private async *handleDefaultModelMessage(
+		modelId: string,
+		systemPrompt: string,
+		messages: Anthropic.Messages.MessageParam[]
+	): ApiStream {
+		const stream = await this.client.chat.completions.create({
+			model: modelId,
+			temperature: 0,
+			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+			stream: true,
+			stream_options: { include_usage: true },
+		})
+
+		yield* this.handleStreamResponse(stream)
+	}
+
+	private async *yieldResponseData(
+		response: OpenAI.Chat.Completions.ChatCompletion
+	): ApiStream {
+		yield {
+			type: "text",
+			text: response.choices[0]?.message.content || "",
+		}
+		yield {
+			type: "usage",
+			inputTokens: response.usage?.prompt_tokens || 0,
+			outputTokens: response.usage?.completion_tokens || 0,
+		}
+	}
+
+	private async *handleStreamResponse(
+		stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
+	): ApiStream {
+		for await (const chunk of stream) {
+			const delta = chunk.choices[0]?.delta
+			if (delta?.content) {
 				yield {
 					type: "text",
-					text: response.choices[0]?.message.content || "",
+					text: delta.content,
 				}
+			}
+
+			if (chunk.usage) {
 				yield {
 					type: "usage",
-					inputTokens: response.usage?.prompt_tokens || 0,
-					outputTokens: response.usage?.completion_tokens || 0,
-				}
-				break
-			}
-			case "o3-mini":
-			case "o3-mini-low":
-			case "o3-mini-high": {
-				const stream = await this.client.chat.completions.create({
-					model: "o3-mini",
-					messages: [{ role: "developer", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-					stream: true,
-					stream_options: { include_usage: true },
-					reasoning_effort: this.getModel().info.reasoningEffort,
-				})
-
-				for await (const chunk of stream) {
-					const delta = chunk.choices[0]?.delta
-					if (delta?.content) {
-						yield {
-							type: "text",
-							text: delta.content,
-						}
-					}
-
-					// contains a null value except for the last chunk which contains the token usage statistics for the entire request
-					if (chunk.usage) {
-						yield {
-							type: "usage",
-							inputTokens: chunk.usage.prompt_tokens || 0,
-							outputTokens: chunk.usage.completion_tokens || 0,
-						}
-					}
-				}
-				break
-			}
-			default: {
-				const stream = await this.client.chat.completions.create({
-					model: this.getModel().id,
-					// max_completion_tokens: this.getModel().info.maxTokens,
-					temperature: 0,
-					messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
-					stream: true,
-					stream_options: { include_usage: true },
-				})
-
-				for await (const chunk of stream) {
-					const delta = chunk.choices[0]?.delta
-					if (delta?.content) {
-						yield {
-							type: "text",
-							text: delta.content,
-						}
-					}
-
-					// contains a null value except for the last chunk which contains the token usage statistics for the entire request
-					if (chunk.usage) {
-						yield {
-							type: "usage",
-							inputTokens: chunk.usage.prompt_tokens || 0,
-							outputTokens: chunk.usage.completion_tokens || 0,
-						}
-					}
+				inputTokens: chunk.usage.prompt_tokens || 0,
+				outputTokens: chunk.usage.completion_tokens || 0,
 			}
 		}
 	}
@@ -125,32 +148,12 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 		const modelId = this.getModel().id
 		let requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming
 
-		switch (modelId) {
-			case "o1":
-			case "o1-preview":
-			case "o1-mini":
-				// o1 doesn't support non-1 temp
-				requestOptions = {
-					model: modelId,
-					messages: [{ role: "user", content: prompt }],
-				}
-				break
-			case "o3-mini":
-			case "o3-mini-low":
-			case "o3-mini-high":
-				// o3 doesn't support non-1 temp
-				requestOptions = {
-					model: "o3-mini",
-					messages: [{ role: "user", content: prompt }],
-					reasoning_effort: this.getModel().info.reasoningEffort,
-				}
-				break
-			default:
-				requestOptions = {
-					model: modelId,
-					messages: [{ role: "user", content: prompt }],
-					temperature: 0,
-				}
+		if (modelId.startsWith("o1")) {
+			requestOptions = this.getO1CompletionOptions(modelId, prompt)
+		} else if (modelId.startsWith("o3-mini")) {
+			requestOptions = this.getO3CompletionOptions(modelId, prompt)
+		} else {
+			requestOptions = this.getDefaultCompletionOptions(modelId, prompt)
 		}
 
 		const response = await this.client.chat.completions.create(requestOptions)
@@ -162,4 +165,36 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
 			throw error
 		}
 	}
+
+	private getO1CompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+		}
+	}
+
+	private getO3CompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: "o3-mini",
+			messages: [{ role: "user", content: prompt }],
+			reasoning_effort: this.getModel().info.reasoningEffort,
+		}
+	}
+
+	private getDefaultCompletionOptions(
+		modelId: string,
+		prompt: string
+	): OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming {
+		return {
+			model: modelId,
+			messages: [{ role: "user", content: prompt }],
+			temperature: 0,
+		}
+	}
 }
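
For orientation, a minimal usage sketch of the refactored handler; the constructor option names (apiModelId, openAiNativeApiKey) mirror the test file's mockOptions and are assumptions here rather than part of the diff:

	// Hedged sketch: option names and import path below are assumed, not shown in the diff.
	import { OpenAiNativeHandler } from "./openai-native"

	const handler = new OpenAiNativeHandler({
		apiModelId: "o3-mini",
		openAiNativeApiKey: "sk-...", // placeholder key
	})

	// createMessage now routes by model id prefix: "o1*" takes the non-streaming
	// developer/user-prompt path, "o3-mini*" streams with reasoning_effort, and
	// everything else uses the default streaming path with temperature 0.
	const stream = handler.createMessage("You are a helpful assistant.", [
		{ role: "user", content: "Hello!" },
	])

	for await (const chunk of stream) {
		if (chunk.type === "text") {
			process.stdout.write(chunk.text)
		} else if (chunk.type === "usage") {
			console.log(`\n[tokens] in=${chunk.inputTokens} out=${chunk.outputTokens}`)
		}
	}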
