Skip to content

Commit 36bbfd9

Browse files
authored
Merge pull request #1003 from RooVetGit/o1_streaming
Enable streaming for o1
2 parents 7eada8b + 5ffb145 commit 36bbfd9

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

src/api/providers/__tests__/openai-native.test.ts

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,20 @@ describe("OpenAiNativeHandler", () => {
130130
})
131131

132132
mockCreate.mockResolvedValueOnce({
133-
choices: [{ message: { content: null } }],
134-
usage: {
135-
prompt_tokens: 0,
136-
completion_tokens: 0,
137-
total_tokens: 0,
133+
[Symbol.asyncIterator]: async function* () {
134+
yield {
135+
choices: [
136+
{
137+
delta: { content: null },
138+
index: 0,
139+
},
140+
],
141+
usage: {
142+
prompt_tokens: 0,
143+
completion_tokens: 0,
144+
total_tokens: 0,
145+
},
146+
}
138147
},
139148
})
140149

@@ -144,10 +153,7 @@ describe("OpenAiNativeHandler", () => {
144153
results.push(result)
145154
}
146155

147-
expect(results).toEqual([
148-
{ type: "text", text: "" },
149-
{ type: "usage", inputTokens: 0, outputTokens: 0 },
150-
])
156+
expect(results).toEqual([{ type: "usage", inputTokens: 0, outputTokens: 0 }])
151157

152158
// Verify developer role is used for system prompt with o1 model
153159
expect(mockCreate).toHaveBeenCalledWith({
@@ -156,6 +162,8 @@ describe("OpenAiNativeHandler", () => {
156162
{ role: "developer", content: "Formatting re-enabled\n" + systemPrompt },
157163
{ role: "user", content: "Hello!" },
158164
],
165+
stream: true,
166+
stream_options: { include_usage: true },
159167
})
160168
})
161169

src/api/providers/openai-native.ts

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,11 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
5656
},
5757
...convertToOpenAiMessages(messages),
5858
],
59+
stream: true,
60+
stream_options: { include_usage: true },
5961
})
6062

61-
yield* this.yieldResponseData(response)
63+
yield* this.handleStreamResponse(response)
6264
}
6365

6466
private async *handleO3FamilyMessage(

0 commit comments

Comments
 (0)