111 changes: 111 additions & 0 deletions src/api/providers/__tests__/openai.spec.ts
@@ -5,6 +5,7 @@ import { OpenAiHandler } from "../openai"
import { ApiHandlerOptions } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { openAiModelInfoSaneDefaults } from "@roo-code/types"

const mockCreate = vitest.fn()

@@ -197,6 +198,113 @@ describe("OpenAiHandler", () => {
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.reasoning_effort).toBeUndefined()
})

it("should include max_tokens when includeMaxTokens is true", async () => {
const optionsWithMaxTokens: ApiHandlerOptions = {
...mockOptions,
includeMaxTokens: true,
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096,
supportsPromptCache: false,
},
}
const handlerWithMaxTokens = new OpenAiHandler(optionsWithMaxTokens)
const stream = handlerWithMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called with max_tokens
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_tokens).toBe(4096)
})

it("should not include max_tokens when includeMaxTokens is false", async () => {
const optionsWithoutMaxTokens: ApiHandlerOptions = {
...mockOptions,
includeMaxTokens: false,
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096,
supportsPromptCache: false,
},
}
const handlerWithoutMaxTokens = new OpenAiHandler(optionsWithoutMaxTokens)
const stream = handlerWithoutMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called without max_tokens
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_tokens).toBeUndefined()
})

it("should not include max_tokens when includeMaxTokens is undefined", async () => {
const optionsWithUndefinedMaxTokens: ApiHandlerOptions = {
...mockOptions,
// includeMaxTokens is not set, should not include max_tokens
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096,
supportsPromptCache: false,
},
}
const handlerWithDefaultMaxTokens = new OpenAiHandler(optionsWithUndefinedMaxTokens)
const stream = handlerWithDefaultMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called without max_tokens
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_tokens).toBeUndefined()
})

it("should use user-configured modelMaxTokens instead of model default maxTokens", async () => {
const optionsWithUserMaxTokens: ApiHandlerOptions = {
...mockOptions,
includeMaxTokens: true,
modelMaxTokens: 32000, // User-configured value
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096, // Model's default value (should not be used)
supportsPromptCache: false,
},
}
const handlerWithUserMaxTokens = new OpenAiHandler(optionsWithUserMaxTokens)
const stream = handlerWithUserMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called with user-configured modelMaxTokens (32000), not model default maxTokens (4096)
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_tokens).toBe(32000)
})

it("should fallback to model default maxTokens when user modelMaxTokens is not set", async () => {
const optionsWithoutUserMaxTokens: ApiHandlerOptions = {
...mockOptions,
includeMaxTokens: true,
// modelMaxTokens is not set
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096, // Model's default value (should be used as fallback)
supportsPromptCache: false,
},
}
const handlerWithoutUserMaxTokens = new OpenAiHandler(optionsWithoutUserMaxTokens)
const stream = handlerWithoutUserMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called with model default maxTokens (4096) as fallback
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_tokens).toBe(4096)
})
})

describe("error handling", () => {
@@ -333,6 +441,7 @@ describe("OpenAiHandler", () => {
stream: true,
stream_options: { include_usage: true },
temperature: 0,
max_tokens: -1, // Default from openAiModelInfoSaneDefaults
},
{ path: "/models/chat/completions" },
)
@@ -375,6 +484,7 @@ describe("OpenAiHandler", () => {
{ role: "user", content: systemPrompt },
{ role: "user", content: "Hello!" },
],
max_tokens: -1, // Default from openAiModelInfoSaneDefaults
},
{ path: "/models/chat/completions" },
)
@@ -388,6 +498,7 @@ describe("OpenAiHandler", () => {
{
model: azureOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
max_tokens: -1, // Default from openAiModelInfoSaneDefaults
},
{ path: "/models/chat/completions" },
)
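Taken together, the new tests pin down a single resolution rule for `max_tokens`. Below is a minimal sketch of that rule, for illustration only: the helper name `resolveMaxTokens` is not part of this PR, and the real logic lives inline in `createMessage` and `completePrompt` (see the provider diff that follows).

```typescript
// Illustration only: mirrors the behavior asserted by the tests above.
// `includeMaxTokens` and `modelMaxTokens` come from ApiHandlerOptions;
// `maxTokens` comes from the resolved model info; `isAzureAiInference`
// reflects the provider's base-URL check.
function resolveMaxTokens(
	options: { includeMaxTokens?: boolean; modelMaxTokens?: number },
	modelInfo: { maxTokens?: number },
	isAzureAiInference: boolean,
): number | undefined {
	// Only send max_tokens when the user opted in, or when the endpoint
	// is the Azure AI Inference Service.
	if (options.includeMaxTokens !== true && !isAzureAiInference) {
		return undefined
	}
	// A user-configured modelMaxTokens wins; otherwise fall back to the
	// model's default (openAiModelInfoSaneDefaults.maxTokens is -1, hence
	// the `max_tokens: -1` expectations in the Azure AI Inference tests).
	return options.modelMaxTokens || modelInfo.maxTokens
}
```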
52 changes: 33 additions & 19 deletions src/api/providers/openai.ts
@@ -159,8 +159,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
}

// @TODO: Move this to the `getModelParams` function.
if (this.options.includeMaxTokens) {
requestOptions.max_tokens = modelInfo.maxTokens
// Add max_tokens if specified or if using Azure AI Inference Service
if (this.options.includeMaxTokens === true || isAzureAiInference) {
// Use user-configured modelMaxTokens if available, otherwise fall back to model's default maxTokens
requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
}

const stream = await this.client.chat.completions.create(
@@ -222,6 +224,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
: [systemMessage, ...convertToOpenAiMessages(messages)],
}

// Add max_tokens if specified or if using Azure AI Inference Service
if (this.options.includeMaxTokens === true || isAzureAiInference) {
requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
}

const response = await this.client.chat.completions.create(
requestOptions,
this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
@@ -256,12 +263,18 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
async completePrompt(prompt: string): Promise<string> {
try {
const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
const modelInfo = this.getModel().info

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
model: this.getModel().id,
messages: [{ role: "user", content: prompt }],
}

// Add max_tokens if specified or if using Azure AI Inference Service
if (this.options.includeMaxTokens === true || isAzureAiInference) {
requestOptions.max_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
}

const response = await this.client.chat.completions.create(
requestOptions,
isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
@@ -282,25 +295,28 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
systemPrompt: string,
messages: Anthropic.Messages.MessageParam[],
): ApiStream {
if (this.options.openAiStreamingEnabled ?? true) {
const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
const modelInfo = this.getModel().info
const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)

if (this.options.openAiStreamingEnabled ?? true) {
const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
messages: [
{
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
],
stream: true,
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
reasoning_effort: this.getModel().info.reasoningEffort,
}

const stream = await this.client.chat.completions.create(
{
model: modelId,
messages: [
{
role: "developer",
content: `Formatting re-enabled\n${systemPrompt}`,
},
...convertToOpenAiMessages(messages),
],
stream: true,
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
reasoning_effort: this.getModel().info.reasoningEffort,
},
requestOptions,
methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
)

@@ -317,8 +333,6 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
],
}

const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)

const response = await this.client.chat.completions.create(
requestOptions,
methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
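Design note: with this change the same guard, `includeMaxTokens === true || isAzureAiInference`, now gates `max_tokens` in all three request paths (streaming `createMessage`, the non-streaming branch, and `completePrompt`), and a user-configured `modelMaxTokens` takes precedence over the model's default via `modelMaxTokens || modelInfo.maxTokens`. Because the fallback uses `||` rather than `??`, a `modelMaxTokens` of `0` would also fall through to the model default; for Azure AI Inference endpoints that default is `-1` from `openAiModelInfoSaneDefaults`, which is exactly what the updated tests expect.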
@@ -164,6 +164,16 @@ export const OpenAICompatible = ({
onChange={handleInputChange("openAiStreamingEnabled", noTransform)}>
{t("settings:modelInfo.enableStreaming")}
</Checkbox>
<div>
<Checkbox
checked={apiConfiguration?.includeMaxTokens ?? true}
onChange={handleInputChange("includeMaxTokens", noTransform)}>
{t("settings:includeMaxOutputTokens")}
</Checkbox>
<div className="text-sm text-vscode-descriptionForeground ml-6">
{t("settings:includeMaxOutputTokensDescription")}
</div>
</div>
<Checkbox
checked={apiConfiguration?.openAiUseAzure ?? false}
onChange={handleInputChange("openAiUseAzure", noTransform)}>
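Note on the new settings checkbox: the UI defaults it to checked (`includeMaxTokens ?? true`), while the provider only adds `max_tokens` when the stored value is strictly `true` (or the endpoint is Azure AI Inference). An existing profile that has never saved the setting therefore renders the box as checked even though the handler still omits `max_tokens`, which is the behavior the "includeMaxTokens is undefined" test above asserts.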