
Commit bfe2274

AlexandruSmirnov and mrubens authored
Add max tokens checkbox option for OpenAI compatible provider (#4467)
Co-authored-by: AlexandruSmirnov <[email protected]>
Co-authored-by: Matt Rubens <[email protected]>
1 parent 58888d5 commit bfe2274

File tree: 22 files changed, +779 −39 lines

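The new setting shows up in the provider options as the `includeMaxTokens` flag, alongside the existing `modelMaxTokens` override and `openAiCustomModelInfo` metadata. A minimal configuration sketch using only the option names exercised by the tests in this commit (the base URL, model id, and values are illustrative, and the import paths mirror the spec file's):

    // Sketch only: option names come from the tests below; values are illustrative.
    import { OpenAiHandler } from "../openai"
    import type { ApiHandlerOptions } from "../../../shared/api"

    const options: ApiHandlerOptions = {
        openAiBaseUrl: "https://example.com/v1", // any OpenAI-compatible endpoint (illustrative)
        openAiModelId: "gpt-4o-mini", // illustrative model id
        includeMaxTokens: true, // the new checkbox: send an output-token cap with each request
        modelMaxTokens: 32000, // optional user override for the cap
        openAiCustomModelInfo: {
            contextWindow: 128_000,
            maxTokens: 4096, // fallback cap when modelMaxTokens is not set
            supportsPromptCache: false,
        },
    }

    const handler = new OpenAiHandler(options)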

src/api/providers/__tests__/openai.spec.ts

Lines changed: 341 additions & 0 deletions
@@ -5,6 +5,7 @@ import { OpenAiHandler } from "../openai"
 import { ApiHandlerOptions } from "../../../shared/api"
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
+import { openAiModelInfoSaneDefaults } from "@roo-code/types"

 const mockCreate = vitest.fn()

@@ -197,6 +198,113 @@ describe("OpenAiHandler", () => {
 			const callArgs = mockCreate.mock.calls[0][0]
 			expect(callArgs.reasoning_effort).toBeUndefined()
 		})
+
+		it("should include max_tokens when includeMaxTokens is true", async () => {
+			const optionsWithMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				includeMaxTokens: true,
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096,
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithMaxTokens = new OpenAiHandler(optionsWithMaxTokens)
+			const stream = handlerWithMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called with max_tokens
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_completion_tokens).toBe(4096)
+		})
+
+		it("should not include max_tokens when includeMaxTokens is false", async () => {
+			const optionsWithoutMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				includeMaxTokens: false,
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096,
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithoutMaxTokens = new OpenAiHandler(optionsWithoutMaxTokens)
+			const stream = handlerWithoutMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called without max_tokens
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_completion_tokens).toBeUndefined()
+		})
+
+		it("should not include max_tokens when includeMaxTokens is undefined", async () => {
+			const optionsWithUndefinedMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				// includeMaxTokens is not set, should not include max_tokens
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096,
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithDefaultMaxTokens = new OpenAiHandler(optionsWithUndefinedMaxTokens)
+			const stream = handlerWithDefaultMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called without max_tokens
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_completion_tokens).toBeUndefined()
+		})
+
+		it("should use user-configured modelMaxTokens instead of model default maxTokens", async () => {
+			const optionsWithUserMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				includeMaxTokens: true,
+				modelMaxTokens: 32000, // User-configured value
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096, // Model's default value (should not be used)
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithUserMaxTokens = new OpenAiHandler(optionsWithUserMaxTokens)
+			const stream = handlerWithUserMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called with user-configured modelMaxTokens (32000), not model default maxTokens (4096)
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_completion_tokens).toBe(32000)
+		})
+
+		it("should fallback to model default maxTokens when user modelMaxTokens is not set", async () => {
+			const optionsWithoutUserMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				includeMaxTokens: true,
+				// modelMaxTokens is not set
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096, // Model's default value (should be used as fallback)
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithoutUserMaxTokens = new OpenAiHandler(optionsWithoutUserMaxTokens)
+			const stream = handlerWithoutUserMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called with model default maxTokens (4096) as fallback
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_completion_tokens).toBe(4096)
+		})
 	})

 	describe("error handling", () => {
@@ -336,6 +444,10 @@
 				},
 				{ path: "/models/chat/completions" },
 			)
+
+			// Verify max_tokens is NOT included when includeMaxTokens is not set
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs).not.toHaveProperty("max_completion_tokens")
 		})

 		it("should handle non-streaming responses with Azure AI Inference Service", async () => {
@@ -378,6 +490,10 @@
 				},
 				{ path: "/models/chat/completions" },
 			)
+
+			// Verify max_tokens is NOT included when includeMaxTokens is not set
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs).not.toHaveProperty("max_completion_tokens")
 		})

 		it("should handle completePrompt with Azure AI Inference Service", async () => {
@@ -391,6 +507,10 @@
 				},
 				{ path: "/models/chat/completions" },
 			)
+
+			// Verify max_tokens is NOT included when includeMaxTokens is not set
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs).not.toHaveProperty("max_completion_tokens")
 		})
 	})

@@ -433,4 +553,225 @@
 			expect(lastCall[0]).not.toHaveProperty("stream_options")
 		})
 	})
+
+	describe("O3 Family Models", () => {
+		const o3Options = {
+			...mockOptions,
+			openAiModelId: "o3-mini",
+			openAiCustomModelInfo: {
+				contextWindow: 128_000,
+				maxTokens: 65536,
+				supportsPromptCache: false,
+				reasoningEffort: "medium" as "low" | "medium" | "high",
+			},
+		}
+
+		it("should handle O3 model with streaming and include max_completion_tokens when includeMaxTokens is true", async () => {
+			const o3Handler = new OpenAiHandler({
+				...o3Options,
+				includeMaxTokens: true,
+				modelMaxTokens: 32000,
+				modelTemperature: 0.5,
+			})
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = o3Handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "o3-mini",
+					messages: [
+						{
+							role: "developer",
+							content: "Formatting re-enabled\nYou are a helpful assistant.",
+						},
+						{ role: "user", content: "Hello!" },
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+					reasoning_effort: "medium",
+					temperature: 0.5,
+					// O3 models do not support deprecated max_tokens but do support max_completion_tokens
+					max_completion_tokens: 32000,
+				}),
+				{},
+			)
+		})
+
+		it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => {
+			const o3Handler = new OpenAiHandler({
+				...o3Options,
+				includeMaxTokens: false,
+				modelTemperature: 0.7,
+			})
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = o3Handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "o3-mini",
+					messages: [
+						{
+							role: "developer",
+							content: "Formatting re-enabled\nYou are a helpful assistant.",
+						},
+						{ role: "user", content: "Hello!" },
+					],
+					stream: true,
+					stream_options: { include_usage: true },
+					reasoning_effort: "medium",
+					temperature: 0.7,
+				}),
+				{},
+			)
+
+			// Verify max_tokens is NOT included
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs).not.toHaveProperty("max_completion_tokens")
+		})
+
+		it("should handle O3 model non-streaming with reasoning_effort and max_completion_tokens when includeMaxTokens is true", async () => {
+			const o3Handler = new OpenAiHandler({
+				...o3Options,
+				openAiStreamingEnabled: false,
+				includeMaxTokens: true,
+				modelTemperature: 0.3,
+			})
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = o3Handler.createMessage(systemPrompt, messages)
+			const chunks: any[] = []
+			for await (const chunk of stream) {
+				chunks.push(chunk)
+			}
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "o3-mini",
+					messages: [
+						{
+							role: "developer",
+							content: "Formatting re-enabled\nYou are a helpful assistant.",
+						},
+						{ role: "user", content: "Hello!" },
+					],
+					reasoning_effort: "medium",
+					temperature: 0.3,
+					// O3 models do not support deprecated max_tokens but do support max_completion_tokens
+					max_completion_tokens: 65536, // Using default maxTokens from o3Options
+				}),
+				{},
+			)
+
+			// Verify stream is not set
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs).not.toHaveProperty("stream")
+		})
+
+		it("should use default temperature of 0 when not specified for O3 models", async () => {
+			const o3Handler = new OpenAiHandler({
+				...o3Options,
+				// No modelTemperature specified
+			})
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = o3Handler.createMessage(systemPrompt, messages)
+			await stream.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					temperature: 0, // Default temperature
+				}),
+				{},
+			)
+		})
+
+		it("should handle O3 model with Azure AI Inference Service respecting includeMaxTokens", async () => {
+			const o3AzureHandler = new OpenAiHandler({
+				...o3Options,
+				openAiBaseUrl: "https://test.services.ai.azure.com",
+				includeMaxTokens: false, // Should NOT include max_tokens
+			})
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = o3AzureHandler.createMessage(systemPrompt, messages)
+			await stream.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "o3-mini",
+				}),
+				{ path: "/models/chat/completions" },
+			)
+
+			// Verify max_tokens is NOT included when includeMaxTokens is false
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs).not.toHaveProperty("max_completion_tokens")
+		})
+
+		it("should NOT include max_tokens for O3 model with Azure AI Inference Service even when includeMaxTokens is true", async () => {
+			const o3AzureHandler = new OpenAiHandler({
+				...o3Options,
+				openAiBaseUrl: "https://test.services.ai.azure.com",
+				includeMaxTokens: true, // Should include max_tokens
+			})
+			const systemPrompt = "You are a helpful assistant."
+			const messages: Anthropic.Messages.MessageParam[] = [
+				{
+					role: "user",
+					content: "Hello!",
+				},
+			]
+
+			const stream = o3AzureHandler.createMessage(systemPrompt, messages)
+			await stream.next()
+
+			expect(mockCreate).toHaveBeenCalledWith(
+				expect.objectContaining({
+					model: "o3-mini",
+					// O3 models do not support max_tokens
+				}),
+				{ path: "/models/chat/completions" },
+			)
+		})
+	})
 })
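For the O3-family cases, the expectations boil down to: the system prompt is sent as a `developer` message prefixed with "Formatting re-enabled", `reasoning_effort` comes from the custom model info, temperature defaults to 0 when unset, and any cap is expressed as `max_completion_tokens` rather than the deprecated `max_tokens`; on the Azure AI Inference Service path no cap is sent even when `includeMaxTokens` is true. A sketch of a request body consistent with those assertions (values taken from the tests; not the handler's source):

    // Sketch of an O3-family request shaped like the expectations above (illustrative only).
    const o3Request = {
        model: "o3-mini",
        messages: [
            { role: "developer", content: "Formatting re-enabled\nYou are a helpful assistant." },
            { role: "user", content: "Hello!" },
        ],
        stream: true,
        stream_options: { include_usage: true },
        reasoning_effort: "medium",
        temperature: 0, // default when modelTemperature is not configured
        max_completion_tokens: 32000, // omitted when includeMaxTokens is false or on Azure AI Inference
    }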
