
Commit c3c2fd9

Merge pull request #320 from RooVetGit/o1_developer_role
Update openai package and use developer role message for o1
2 parents: 15513f4 + 6e90bcf
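
The substance of the change: the openai package is bumped so the SDK accepts the "developer" chat role, which o1 honors in place of a system prompt; o1-preview and o1-mini still require the prompt to be folded into a user message. Below is a minimal standalone sketch of the pattern the diff adopts, assuming a client built from an environment key (the helper name and client setup are illustrative, not from this repo):

import OpenAI from "openai"

const client = new OpenAI({ apiKey: process.env.OPENAI_API_KEY })

// Illustrative helper; only the role ternary and the create() call mirror the diff.
async function completeO1Family(modelId: string, systemPrompt: string, userText: string) {
  // o1 accepts "developer" for system-level instructions;
  // o1-preview and o1-mini reject it, so fall back to a plain user message.
  const response = await client.chat.completions.create({
    model: modelId,
    messages: [
      { role: modelId === "o1" ? "developer" : "user", content: systemPrompt },
      { role: "user", content: userText },
    ],
  })
  return response.choices[0]?.message?.content ?? ""
}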

File tree

4 files changed: +133 -21 lines changed


package-lock.json

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default.

package.json

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@
     "isbinaryfile": "^5.0.2",
     "mammoth": "^1.8.0",
     "monaco-vscode-textmate-theme-converter": "^0.1.7",
-    "openai": "^4.73.1",
+    "openai": "^4.78.1",
     "os-name": "^6.0.0",
     "p-wait-for": "^5.0.2",
     "pdf-parse": "^1.1.1",

src/api/providers/__tests__/openai-native.test.ts

Lines changed: 121 additions & 11 deletions
@@ -60,6 +60,13 @@ jest.mock('openai', () => {
 describe('OpenAiNativeHandler', () => {
   let handler: OpenAiNativeHandler;
   let mockOptions: ApiHandlerOptions;
+  const systemPrompt = 'You are a helpful assistant.';
+  const messages: Anthropic.Messages.MessageParam[] = [
+    {
+      role: 'user',
+      content: 'Hello!'
+    }
+  ];
 
   beforeEach(() => {
     mockOptions = {
@@ -86,14 +93,6 @@ describe('OpenAiNativeHandler', () => {
   });
 
   describe('createMessage', () => {
-    const systemPrompt = 'You are a helpful assistant.';
-    const messages: Anthropic.Messages.MessageParam[] = [
-      {
-        role: 'user',
-        content: 'Hello!'
-      }
-    ];
-
     it('should handle streaming responses', async () => {
       const stream = handler.createMessage(systemPrompt, messages);
       const chunks: any[] = [];
@@ -109,15 +108,126 @@ describe('OpenAiNativeHandler', () => {
 
     it('should handle API errors', async () => {
       mockCreate.mockRejectedValueOnce(new Error('API Error'));
-
       const stream = handler.createMessage(systemPrompt, messages);
-
       await expect(async () => {
         for await (const chunk of stream) {
           // Should not reach here
         }
       }).rejects.toThrow('API Error');
     });
+
+    it('should handle missing content in response for o1 model', async () => {
+      // Use o1 model which supports developer role
+      handler = new OpenAiNativeHandler({
+        ...mockOptions,
+        apiModelId: 'o1'
+      });
+
+      mockCreate.mockResolvedValueOnce({
+        choices: [{ message: { content: null } }],
+        usage: {
+          prompt_tokens: 0,
+          completion_tokens: 0,
+          total_tokens: 0
+        }
+      });
+
+      const generator = handler.createMessage(systemPrompt, messages);
+      const results = [];
+      for await (const result of generator) {
+        results.push(result);
+      }
+
+      expect(results).toEqual([
+        { type: 'text', text: '' },
+        { type: 'usage', inputTokens: 0, outputTokens: 0 }
+      ]);
+
+      // Verify developer role is used for system prompt with o1 model
+      expect(mockCreate).toHaveBeenCalledWith({
+        model: 'o1',
+        messages: [
+          { role: 'developer', content: systemPrompt },
+          { role: 'user', content: 'Hello!' }
+        ]
+      });
+    });
+  });
+
+  describe('streaming models', () => {
+    beforeEach(() => {
+      handler = new OpenAiNativeHandler({
+        ...mockOptions,
+        apiModelId: 'gpt-4o',
+      });
+    });
+
+    it('should handle streaming response', async () => {
+      const mockStream = [
+        { choices: [{ delta: { content: 'Hello' } }], usage: null },
+        { choices: [{ delta: { content: ' there' } }], usage: null },
+        { choices: [{ delta: { content: '!' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
+      ];
+
+      mockCreate.mockResolvedValueOnce(
+        (async function* () {
+          for (const chunk of mockStream) {
+            yield chunk;
+          }
+        })()
+      );
+
+      const generator = handler.createMessage(systemPrompt, messages);
+      const results = [];
+      for await (const result of generator) {
+        results.push(result);
+      }
+
+      expect(results).toEqual([
+        { type: 'text', text: 'Hello' },
+        { type: 'text', text: ' there' },
+        { type: 'text', text: '!' },
+        { type: 'usage', inputTokens: 10, outputTokens: 5 },
+      ]);
+
+      expect(mockCreate).toHaveBeenCalledWith({
+        model: 'gpt-4o',
+        temperature: 0,
+        messages: [
+          { role: 'system', content: systemPrompt },
+          { role: 'user', content: 'Hello!' },
+        ],
+        stream: true,
+        stream_options: { include_usage: true },
+      });
+    });
+
+    it('should handle empty delta content', async () => {
+      const mockStream = [
+        { choices: [{ delta: {} }], usage: null },
+        { choices: [{ delta: { content: null } }], usage: null },
+        { choices: [{ delta: { content: 'Hello' } }], usage: { prompt_tokens: 10, completion_tokens: 5 } },
+      ];
+
+      mockCreate.mockResolvedValueOnce(
+        (async function* () {
+          for (const chunk of mockStream) {
+            yield chunk;
+          }
+        })()
+      );
+
+      const generator = handler.createMessage(systemPrompt, messages);
+      const results = [];
+      for await (const result of generator) {
+        results.push(result);
+      }
+
+      expect(results).toEqual([
+        { type: 'text', text: 'Hello' },
+        { type: 'usage', inputTokens: 10, outputTokens: 5 },
+      ]);
+    });
   });
 
   describe('completePrompt', () => {
@@ -206,4 +316,4 @@ describe('OpenAiNativeHandler', () => {
       expect(modelInfo.info).toBeDefined();
     });
   });
-});
+});
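
A note on the stubbing pattern in the streaming tests above: the handler consumes the SDK stream with for await, so resolving the mocked create call with an immediately-invoked async generator is all the fake needs. The pattern in isolation (a sketch; mockCreate stands for the jest mock of client.chat.completions.create wired up in this file's jest.mock('openai', ...) block):

const mockCreate = jest.fn()

// Any async iterable satisfies the handler's `for await` loop.
mockCreate.mockResolvedValueOnce(
  (async function* () {
    yield { choices: [{ delta: { content: 'Hello' } }], usage: null }
    yield { choices: [{ delta: {} }], usage: { prompt_tokens: 1, completion_tokens: 1 } }
  })()
)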

src/api/providers/openai-native.ts

Lines changed: 7 additions & 5 deletions
@@ -23,14 +23,16 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
   }
 
   async *createMessage(systemPrompt: string, messages: Anthropic.Messages.MessageParam[]): ApiStream {
-    switch (this.getModel().id) {
+    const modelId = this.getModel().id
+    switch (modelId) {
       case "o1":
       case "o1-preview":
       case "o1-mini": {
-        // o1 doesnt support streaming, non-1 temp, or system prompt
+        // o1-preview and o1-mini don't support streaming, non-1 temp, or system prompt
+        // o1 doesnt support streaming or non-1 temp but does support a developer prompt
         const response = await this.client.chat.completions.create({
-          model: this.getModel().id,
-          messages: [{ role: "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
+          model: modelId,
+          messages: [{ role: modelId === "o1" ? "developer" : "user", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
         })
         yield {
           type: "text",
@@ -93,7 +95,7 @@ export class OpenAiNativeHandler implements ApiHandler, SingleCompletionHandler
       case "o1":
       case "o1-preview":
       case "o1-mini":
-        // o1 doesn't support non-1 temp or system prompt
+        // o1 doesn't support non-1 temp
         requestOptions = {
           model: modelId,
           messages: [{ role: "user", content: prompt }]
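
Because the o1 family rejects stream: true, the o1 branch awaits one complete response and yields exactly two chunks, a text chunk then a usage chunk, which is what the "missing content" test asserts. The shape of that branch in isolation (a sketch consistent with the diff and tests, not copied from the file; the fallback to an empty string matches the { type: 'text', text: '' } expectation):

import OpenAI from "openai"

// Sketch: the non-streaming o1 path as a standalone async generator.
async function* o1Branch(client: OpenAI, modelId: string, messages: any[]) {
  const response = await client.chat.completions.create({ model: modelId, messages })
  yield { type: "text", text: response.choices[0]?.message?.content || "" }
  yield {
    type: "usage",
    inputTokens: response.usage?.prompt_tokens || 0,
    outputTokens: response.usage?.completion_tokens || 0,
  }
}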
