
Commit 508c245

Author: AlexandruSmirnov (committed)

Add max tokens checkbox option for OpenAI compatible provider
- Add checkbox control to enable/disable max tokens in API requests
- Update OpenAI compatible provider UI with max tokens option
- Add test coverage for the new max tokens functionality
- Update localization files across all supported languages
1 parent 360b0e7 commit 508c245
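
Functionally, the new checkbox maps to the includeMaxTokens flag on ApiHandlerOptions. The sketch below is illustrative only (the endpoint and model id values are placeholders, and the import paths are copied from the spec file further down); it uses only option names that occur in this diff:

// Sketch only: values are placeholders, not part of this commit.
import { OpenAiHandler } from "../openai" // path as in src/api/providers/__tests__/openai.spec.ts
import type { ApiHandlerOptions } from "../../../shared/api"

const options: ApiHandlerOptions = {
	openAiBaseUrl: "https://example.com/v1", // placeholder endpoint
	openAiModelId: "my-model", // placeholder model id
	includeMaxTokens: true, // the new checkbox-backed flag
	modelMaxTokens: 32000, // optional user-configured cap
	openAiCustomModelInfo: { contextWindow: 128_000, maxTokens: 4096, supportsPromptCache: false },
}

// includeMaxTokens: true -> requests carry max_tokens: 32000 (modelMaxTokens || modelInfo.maxTokens)
// includeMaxTokens: false or unset -> max_tokens is omitted, except on Azure AI Inference endpoints
const handler = new OpenAiHandler(options)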

File tree: 19 files changed, +513 / -34 lines


src/api/providers/__tests__/openai.spec.ts

Lines changed: 111 additions & 0 deletions
@@ -5,6 +5,7 @@ import { OpenAiHandler } from "../openai"
 import { ApiHandlerOptions } from "../../../shared/api"
 import { Anthropic } from "@anthropic-ai/sdk"
 import OpenAI from "openai"
+import { openAiModelInfoSaneDefaults } from "@roo-code/types"

 const mockCreate = vitest.fn()

@@ -197,6 +198,113 @@ describe("OpenAiHandler", () => {
 			const callArgs = mockCreate.mock.calls[0][0]
 			expect(callArgs.reasoning_effort).toBeUndefined()
 		})
+
+		it("should include max_tokens when includeMaxTokens is true", async () => {
+			const optionsWithMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				includeMaxTokens: true,
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096,
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithMaxTokens = new OpenAiHandler(optionsWithMaxTokens)
+			const stream = handlerWithMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called with max_tokens
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_tokens).toBe(4096)
+		})
+
+		it("should not include max_tokens when includeMaxTokens is false", async () => {
+			const optionsWithoutMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				includeMaxTokens: false,
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096,
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithoutMaxTokens = new OpenAiHandler(optionsWithoutMaxTokens)
+			const stream = handlerWithoutMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called without max_tokens
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_tokens).toBeUndefined()
+		})
+
+		it("should not include max_tokens when includeMaxTokens is undefined", async () => {
+			const optionsWithUndefinedMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				// includeMaxTokens is not set, should not include max_tokens
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096,
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithDefaultMaxTokens = new OpenAiHandler(optionsWithUndefinedMaxTokens)
+			const stream = handlerWithDefaultMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called without max_tokens
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_tokens).toBeUndefined()
+		})
+
+		it("should use user-configured modelMaxTokens instead of model default maxTokens", async () => {
+			const optionsWithUserMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				includeMaxTokens: true,
+				modelMaxTokens: 32000, // User-configured value
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096, // Model's default value (should not be used)
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithUserMaxTokens = new OpenAiHandler(optionsWithUserMaxTokens)
+			const stream = handlerWithUserMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called with user-configured modelMaxTokens (32000), not model default maxTokens (4096)
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_tokens).toBe(32000)
+		})
+
+		it("should fallback to model default maxTokens when user modelMaxTokens is not set", async () => {
+			const optionsWithoutUserMaxTokens: ApiHandlerOptions = {
+				...mockOptions,
+				includeMaxTokens: true,
+				// modelMaxTokens is not set
+				openAiCustomModelInfo: {
+					contextWindow: 128_000,
+					maxTokens: 4096, // Model's default value (should be used as fallback)
+					supportsPromptCache: false,
+				},
+			}
+			const handlerWithoutUserMaxTokens = new OpenAiHandler(optionsWithoutUserMaxTokens)
+			const stream = handlerWithoutUserMaxTokens.createMessage(systemPrompt, messages)
+			// Consume the stream to trigger the API call
+			for await (const _chunk of stream) {
+			}
+			// Assert the mockCreate was called with model default maxTokens (4096) as fallback
+			expect(mockCreate).toHaveBeenCalled()
+			const callArgs = mockCreate.mock.calls[0][0]
+			expect(callArgs.max_tokens).toBe(4096)
+		})
 	})

 	describe("error handling", () => {
@@ -333,6 +441,7 @@ describe("OpenAiHandler", () => {
 					stream: true,
 					stream_options: { include_usage: true },
 					temperature: 0,
+					max_tokens: -1,
 				},
 				{ path: "/models/chat/completions" },
 			)
@@ -375,6 +484,7 @@ describe("OpenAiHandler", () => {
 					{ role: "user", content: systemPrompt },
 					{ role: "user", content: "Hello!" },
 				],
+				max_tokens: -1, // Default from openAiModelInfoSaneDefaults
 			},
 			{ path: "/models/chat/completions" },
 		)
@@ -388,6 +498,7 @@ describe("OpenAiHandler", () => {
 			{
 				model: azureOptions.openAiModelId,
 				messages: [{ role: "user", content: "Test prompt" }],
+				max_tokens: -1, // Default from openAiModelInfoSaneDefaults
 			},
 			{ path: "/models/chat/completions" },
 		)
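
The max_tokens: -1 assertions in the last three hunks belong to requests sent with { path: "/models/chat/completions" }, i.e. the Azure AI Inference route, where max_tokens is now always included. Assuming the handler falls back to openAiModelInfoSaneDefaults when no openAiCustomModelInfo is configured (which is what the inline comments above state), the value resolves as in this short sketch; the variable names are illustrative:

import { openAiModelInfoSaneDefaults } from "@roo-code/types"

// No user-configured cap and no custom model info in these tests, so the
// fallback chain lands on the sane-default maxTokens, which is -1.
const userConfiguredMaxTokens: number | undefined = undefined // modelMaxTokens is not set
const maxTokens = userConfiguredMaxTokens || openAiModelInfoSaneDefaults.maxTokens // -1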

src/api/providers/openai.ts

Lines changed: 33 additions & 19 deletions
@@ -159,8 +159,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		}

 		// @TODO: Move this to the `getModelParams` function.
-		if (this.options.includeMaxTokens) {
-			requestOptions.max_tokens = modelInfo.maxTokens
+		// Add max_tokens if specified or if using Azure AI Inference Service
+		if (this.options.includeMaxTokens === true || isAzureAiInference) {
+			// Use user-configured modelMaxTokens if available, otherwise fall back to model's default maxTokens
+			requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
 		}

 		const stream = await this.client.chat.completions.create(
@@ -222,6 +224,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				: [systemMessage, ...convertToOpenAiMessages(messages)],
 		}

+		// Add max_tokens if specified or if using Azure AI Inference Service
+		if (this.options.includeMaxTokens === true || isAzureAiInference) {
+			requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
+		}
+
 		const response = await this.client.chat.completions.create(
 			requestOptions,
 			this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
@@ -256,12 +263,18 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 	async completePrompt(prompt: string): Promise<string> {
 		try {
 			const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+			const modelInfo = this.getModel().info

 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 				model: this.getModel().id,
 				messages: [{ role: "user", content: prompt }],
 			}

+			// Add max_tokens if specified or if using Azure AI Inference Service
+			if (this.options.includeMaxTokens === true || isAzureAiInference) {
+				requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
+			}
+
 			const response = await this.client.chat.completions.create(
 				requestOptions,
 				isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
@@ -282,25 +295,28 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		systemPrompt: string,
 		messages: Anthropic.Messages.MessageParam[],
 	): ApiStream {
-		if (this.options.openAiStreamingEnabled ?? true) {
-			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
+		const modelInfo = this.getModel().info
+		const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)

+		if (this.options.openAiStreamingEnabled ?? true) {
 			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)

+			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+				model: modelId,
+				messages: [
+					{
+						role: "developer",
+						content: `Formatting re-enabled\n${systemPrompt}`,
+					},
+					...convertToOpenAiMessages(messages),
+				],
+				stream: true,
+				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
+				reasoning_effort: this.getModel().info.reasoningEffort,
+			}
+
 			const stream = await this.client.chat.completions.create(
-				{
-					model: modelId,
-					messages: [
-						{
-							role: "developer",
-							content: `Formatting re-enabled\n${systemPrompt}`,
-						},
-						...convertToOpenAiMessages(messages),
-					],
-					stream: true,
-					...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
-					reasoning_effort: this.getModel().info.reasoningEffort,
-				},
+				requestOptions,
 				methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
 			)

@@ -317,8 +333,6 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 				],
 			}

-			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
-
 			const response = await this.client.chat.completions.create(
 				requestOptions,
 				methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
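
The same guard now appears in three places in this file: the streaming createMessage path, the non-streaming path, and completePrompt. The existing @TODO already points at folding it into getModelParams; as a hedged illustration only (the helper name and signature below are hypothetical, not part of this commit), the shared logic amounts to:

// Hypothetical helper mirroring the inlined guard in openai.ts; not part of this commit.
interface MaxTokensOptions {
	includeMaxTokens?: boolean
	modelMaxTokens?: number
}

function resolveMaxTokens(
	options: MaxTokensOptions,
	modelInfo: { maxTokens?: number },
	isAzureAiInference: boolean,
): number | undefined {
	// max_tokens is only sent when explicitly enabled, or always for Azure AI Inference.
	if (options.includeMaxTokens !== true && !isAzureAiInference) {
		return undefined
	}
	// A user-configured modelMaxTokens wins; otherwise fall back to the model's default maxTokens.
	return options.modelMaxTokens || modelInfo.maxTokens
}

// Behavior pinned down by the spec file above:
resolveMaxTokens({ includeMaxTokens: true, modelMaxTokens: 32000 }, { maxTokens: 4096 }, false) // 32000
resolveMaxTokens({ includeMaxTokens: true }, { maxTokens: 4096 }, false) // 4096
resolveMaxTokens({}, { maxTokens: 4096 }, false) // undefined: flag unset, max_tokens omitted
resolveMaxTokens({}, { maxTokens: -1 }, true) // -1: Azure AI Inference always includes it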

webview-ui/src/components/settings/providers/OpenAICompatible.tsx

Lines changed: 10 additions & 0 deletions
@@ -164,6 +164,16 @@ export const OpenAICompatible = ({
 					onChange={handleInputChange("openAiStreamingEnabled", noTransform)}>
 					{t("settings:modelInfo.enableStreaming")}
 				</Checkbox>
+				<div>
+					<Checkbox
+						checked={apiConfiguration?.includeMaxTokens ?? true}
+						onChange={handleInputChange("includeMaxTokens", noTransform)}>
+						{t("settings:includeMaxOutputTokens")}
+					</Checkbox>
+					<div className="text-sm text-vscode-descriptionForeground ml-6">
+						{t("settings:includeMaxOutputTokensDescription")}
+					</div>
+				</div>
 				<Checkbox
 					checked={apiConfiguration?.openAiUseAzure ?? false}
 					onChange={handleInputChange("openAiUseAzure", noTransform)}>
