Skip to content

Commit 0b72b68

Browse files
author
AlexandruSmirnov
committed
fix: O3 model max_tokens support and code optimizations
- Fixed max_tokens support for O3 models in OpenAI provider
- Refactored OpenAI provider to eliminate code duplication with addMaxTokensIfNeeded helper
- Made Azure AI Inference Service respect the includeMaxTokens checkbox setting
- Applied code optimizations to reduce redundancy
- Added missing translations for includeMaxTokens in Catalan and German locales
- Updated tests to cover new functionality
1 parent 508c245 commit 0b72b68

File tree

4 files changed

+273
-22
lines changed

4 files changed

+273
-22
lines changed

src/api/providers/__tests__/openai.spec.ts

Lines changed: 231 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,13 @@ describe("OpenAiHandler", () => {
441441
stream: true,
442442
stream_options: { include_usage: true },
443443
temperature: 0,
444-
max_tokens: -1,
445444
},
446445
{ path: "/models/chat/completions" },
447446
)
447+
448+
// Verify max_tokens is NOT included when includeMaxTokens is not set
449+
const callArgs = mockCreate.mock.calls[0][0]
450+
expect(callArgs).not.toHaveProperty("max_tokens")
448451
})
449452

450453
it("should handle non-streaming responses with Azure AI Inference Service", async () => {
@@ -484,10 +487,13 @@ describe("OpenAiHandler", () => {
484487
{ role: "user", content: systemPrompt },
485488
{ role: "user", content: "Hello!" },
486489
],
487-
max_tokens: -1, // Default from openAiModelInfoSaneDefaults
488490
},
489491
{ path: "/models/chat/completions" },
490492
)
493+
494+
// Verify max_tokens is NOT included when includeMaxTokens is not set
495+
const callArgs = mockCreate.mock.calls[0][0]
496+
expect(callArgs).not.toHaveProperty("max_tokens")
491497
})
492498

493499
it("should handle completePrompt with Azure AI Inference Service", async () => {
@@ -498,10 +504,13 @@ describe("OpenAiHandler", () => {
498504
{
499505
model: azureOptions.openAiModelId,
500506
messages: [{ role: "user", content: "Test prompt" }],
501-
max_tokens: -1, // Default from openAiModelInfoSaneDefaults
502507
},
503508
{ path: "/models/chat/completions" },
504509
)
510+
511+
// Verify max_tokens is NOT included when includeMaxTokens is not set
512+
const callArgs = mockCreate.mock.calls[0][0]
513+
expect(callArgs).not.toHaveProperty("max_tokens")
505514
})
506515
})
507516

@@ -544,4 +553,223 @@ describe("OpenAiHandler", () => {
544553
expect(lastCall[0]).not.toHaveProperty("stream_options")
545554
})
546555
})
556+
557+
describe("O3 Family Models", () => {
558+
const o3Options = {
559+
...mockOptions,
560+
openAiModelId: "o3-mini",
561+
openAiCustomModelInfo: {
562+
contextWindow: 128_000,
563+
maxTokens: 65536,
564+
supportsPromptCache: false,
565+
reasoningEffort: "medium" as "low" | "medium" | "high",
566+
},
567+
}
568+
569+
it("should handle O3 model with streaming and include max_tokens when includeMaxTokens is true", async () => {
570+
const o3Handler = new OpenAiHandler({
571+
...o3Options,
572+
includeMaxTokens: true,
573+
modelMaxTokens: 32000,
574+
modelTemperature: 0.5,
575+
})
576+
const systemPrompt = "You are a helpful assistant."
577+
const messages: Anthropic.Messages.MessageParam[] = [
578+
{
579+
role: "user",
580+
content: "Hello!",
581+
},
582+
]
583+
584+
const stream = o3Handler.createMessage(systemPrompt, messages)
585+
const chunks: any[] = []
586+
for await (const chunk of stream) {
587+
chunks.push(chunk)
588+
}
589+
590+
expect(mockCreate).toHaveBeenCalledWith(
591+
expect.objectContaining({
592+
model: "o3-mini",
593+
messages: [
594+
{
595+
role: "developer",
596+
content: "Formatting re-enabled\nYou are a helpful assistant.",
597+
},
598+
{ role: "user", content: "Hello!" },
599+
],
600+
stream: true,
601+
stream_options: { include_usage: true },
602+
reasoning_effort: "medium",
603+
temperature: 0.5,
604+
max_tokens: 32000,
605+
}),
606+
{},
607+
)
608+
})
609+
610+
it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => {
611+
const o3Handler = new OpenAiHandler({
612+
...o3Options,
613+
includeMaxTokens: false,
614+
modelTemperature: 0.7,
615+
})
616+
const systemPrompt = "You are a helpful assistant."
617+
const messages: Anthropic.Messages.MessageParam[] = [
618+
{
619+
role: "user",
620+
content: "Hello!",
621+
},
622+
]
623+
624+
const stream = o3Handler.createMessage(systemPrompt, messages)
625+
const chunks: any[] = []
626+
for await (const chunk of stream) {
627+
chunks.push(chunk)
628+
}
629+
630+
expect(mockCreate).toHaveBeenCalledWith(
631+
expect.objectContaining({
632+
model: "o3-mini",
633+
messages: [
634+
{
635+
role: "developer",
636+
content: "Formatting re-enabled\nYou are a helpful assistant.",
637+
},
638+
{ role: "user", content: "Hello!" },
639+
],
640+
stream: true,
641+
stream_options: { include_usage: true },
642+
reasoning_effort: "medium",
643+
temperature: 0.7,
644+
}),
645+
{},
646+
)
647+
648+
// Verify max_tokens is NOT included
649+
const callArgs = mockCreate.mock.calls[0][0]
650+
expect(callArgs).not.toHaveProperty("max_tokens")
651+
})
652+
653+
it("should handle O3 model non-streaming with max_tokens and reasoning_effort", async () => {
654+
const o3Handler = new OpenAiHandler({
655+
...o3Options,
656+
openAiStreamingEnabled: false,
657+
includeMaxTokens: true,
658+
modelTemperature: 0.3,
659+
})
660+
const systemPrompt = "You are a helpful assistant."
661+
const messages: Anthropic.Messages.MessageParam[] = [
662+
{
663+
role: "user",
664+
content: "Hello!",
665+
},
666+
]
667+
668+
const stream = o3Handler.createMessage(systemPrompt, messages)
669+
const chunks: any[] = []
670+
for await (const chunk of stream) {
671+
chunks.push(chunk)
672+
}
673+
674+
expect(mockCreate).toHaveBeenCalledWith(
675+
expect.objectContaining({
676+
model: "o3-mini",
677+
messages: [
678+
{
679+
role: "developer",
680+
content: "Formatting re-enabled\nYou are a helpful assistant.",
681+
},
682+
{ role: "user", content: "Hello!" },
683+
],
684+
reasoning_effort: "medium",
685+
temperature: 0.3,
686+
max_tokens: 65536, // Falls back to model default
687+
}),
688+
{},
689+
)
690+
691+
// Verify stream is not set
692+
const callArgs = mockCreate.mock.calls[0][0]
693+
expect(callArgs).not.toHaveProperty("stream")
694+
})
695+
696+
it("should use default temperature of 0 when not specified for O3 models", async () => {
697+
const o3Handler = new OpenAiHandler({
698+
...o3Options,
699+
// No modelTemperature specified
700+
})
701+
const systemPrompt = "You are a helpful assistant."
702+
const messages: Anthropic.Messages.MessageParam[] = [
703+
{
704+
role: "user",
705+
content: "Hello!",
706+
},
707+
]
708+
709+
const stream = o3Handler.createMessage(systemPrompt, messages)
710+
await stream.next()
711+
712+
expect(mockCreate).toHaveBeenCalledWith(
713+
expect.objectContaining({
714+
temperature: 0, // Default temperature
715+
}),
716+
{},
717+
)
718+
})
719+
720+
it("should handle O3 model with Azure AI Inference Service respecting includeMaxTokens", async () => {
721+
const o3AzureHandler = new OpenAiHandler({
722+
...o3Options,
723+
openAiBaseUrl: "https://test.services.ai.azure.com",
724+
includeMaxTokens: false, // Should NOT include max_tokens
725+
})
726+
const systemPrompt = "You are a helpful assistant."
727+
const messages: Anthropic.Messages.MessageParam[] = [
728+
{
729+
role: "user",
730+
content: "Hello!",
731+
},
732+
]
733+
734+
const stream = o3AzureHandler.createMessage(systemPrompt, messages)
735+
await stream.next()
736+
737+
expect(mockCreate).toHaveBeenCalledWith(
738+
expect.objectContaining({
739+
model: "o3-mini",
740+
}),
741+
{ path: "/models/chat/completions" },
742+
)
743+
744+
// Verify max_tokens is NOT included when includeMaxTokens is false
745+
const callArgs = mockCreate.mock.calls[0][0]
746+
expect(callArgs).not.toHaveProperty("max_tokens")
747+
})
748+
749+
it("should include max_tokens for O3 model with Azure AI Inference Service when includeMaxTokens is true", async () => {
750+
const o3AzureHandler = new OpenAiHandler({
751+
...o3Options,
752+
openAiBaseUrl: "https://test.services.ai.azure.com",
753+
includeMaxTokens: true, // Should include max_tokens
754+
})
755+
const systemPrompt = "You are a helpful assistant."
756+
const messages: Anthropic.Messages.MessageParam[] = [
757+
{
758+
role: "user",
759+
content: "Hello!",
760+
},
761+
]
762+
763+
const stream = o3AzureHandler.createMessage(systemPrompt, messages)
764+
await stream.next()
765+
766+
expect(mockCreate).toHaveBeenCalledWith(
767+
expect.objectContaining({
768+
model: "o3-mini",
769+
max_tokens: 65536, // Included when includeMaxTokens is true
770+
}),
771+
{ path: "/models/chat/completions" },
772+
)
773+
})
774+
})
547775
})

src/api/providers/openai.ts

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
158158
...(reasoning && reasoning),
159159
}
160160

161-
// @TODO: Move this to the `getModelParams` function.
162-
// Add max_tokens if specified or if using Azure AI Inference Service
163-
if (this.options.includeMaxTokens === true || isAzureAiInference) {
164-
// Use user-configured modelMaxTokens if available, otherwise fall back to model's default maxTokens
165-
requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
166-
}
161+
// Add max_tokens if needed
162+
this.addMaxTokensIfNeeded(requestOptions, modelInfo, isAzureAiInference)
167163

168164
const stream = await this.client.chat.completions.create(
169165
requestOptions,
@@ -224,10 +220,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
224220
: [systemMessage, ...convertToOpenAiMessages(messages)],
225221
}
226222

227-
// Add max_tokens if specified or if using Azure AI Inference Service
228-
if (this.options.includeMaxTokens === true || isAzureAiInference) {
229-
requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
230-
}
223+
// Add max_tokens if needed
224+
this.addMaxTokensIfNeeded(requestOptions, modelInfo, isAzureAiInference)
231225

232226
const response = await this.client.chat.completions.create(
233227
requestOptions,
@@ -263,17 +257,16 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
263257
async completePrompt(prompt: string): Promise<string> {
264258
try {
265259
const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
266-
const modelInfo = this.getModel().info
260+
const model = this.getModel()
261+
const modelInfo = model.info
267262

268263
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
269-
model: this.getModel().id,
264+
model: model.id,
270265
messages: [{ role: "user", content: prompt }],
271266
}
272267

273-
// Add max_tokens if specified or if using Azure AI Inference Service
274-
if (this.options.includeMaxTokens === true || isAzureAiInference) {
275-
requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
276-
}
268+
// Add max_tokens if needed
269+
this.addMaxTokensIfNeeded(requestOptions, modelInfo, isAzureAiInference)
277270

278271
const response = await this.client.chat.completions.create(
279272
requestOptions,
@@ -312,9 +305,13 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
312305
],
313306
stream: true,
314307
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
315-
reasoning_effort: this.getModel().info.reasoningEffort,
308+
reasoning_effort: modelInfo.reasoningEffort,
309+
temperature: this.options.modelTemperature ?? 0,
316310
}
317311

312+
// Add max_tokens if needed
313+
this.addMaxTokensIfNeeded(requestOptions, modelInfo, methodIsAzureAiInference)
314+
318315
const stream = await this.client.chat.completions.create(
319316
requestOptions,
320317
methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
@@ -331,8 +328,13 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
331328
},
332329
...convertToOpenAiMessages(messages),
333330
],
331+
reasoning_effort: modelInfo.reasoningEffort,
332+
temperature: this.options.modelTemperature ?? 0,
334333
}
335334

335+
// Add max_tokens if needed
336+
this.addMaxTokensIfNeeded(requestOptions, modelInfo, methodIsAzureAiInference)
337+
336338
const response = await this.client.chat.completions.create(
337339
requestOptions,
338340
methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
@@ -383,6 +385,23 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
383385
const urlHost = this._getUrlHost(baseUrl)
384386
return urlHost.endsWith(".services.ai.azure.com")
385387
}
388+
389+
/**
390+
* Adds max_tokens to the request body if needed based on provider configuration
391+
*/
392+
private addMaxTokensIfNeeded(
393+
requestOptions:
394+
| OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming
395+
| OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming,
396+
modelInfo: ModelInfo,
397+
isAzureAiInference: boolean,
398+
): void {
399+
// Only add max_tokens if includeMaxTokens is true
400+
if (this.options.includeMaxTokens === true) {
401+
// Use user-configured modelMaxTokens if available, otherwise fall back to model's default maxTokens
402+
requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
403+
}
404+
}
386405
}
387406

388407
export async function getOpenAiModels(baseUrl?: string, apiKey?: string, openAiHeaders?: Record<string, string>) {

webview-ui/src/i18n/locales/ca/settings.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,5 +601,7 @@
601601
"labels": {
602602
"customArn": "ARN personalitzat",
603603
"useCustomArn": "Utilitza ARN personalitzat..."
604-
}
604+
},
605+
"includeMaxOutputTokens": "Incloure tokens màxims de sortida",
606+
"includeMaxOutputTokensDescription": "Enviar el paràmetre de tokens màxims de sortida a les sol·licituds API. Alguns proveïdors poden no admetre això."
605607
}

webview-ui/src/i18n/locales/de/settings.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,5 +601,7 @@
601601
"labels": {
602602
"customArn": "Benutzerdefinierte ARN",
603603
"useCustomArn": "Benutzerdefinierte ARN verwenden..."
604-
}
604+
},
605+
"includeMaxOutputTokens": "Maximale Ausgabe-Tokens einbeziehen",
606+
"includeMaxOutputTokensDescription": "Sende den Parameter für maximale Ausgabe-Tokens in API-Anfragen. Einige Anbieter unterstützen dies möglicherweise nicht."
605607
}

0 commit comments

Comments (0)