From c04e01953a92ff467a8b6868b256a1c423d5ed51 Mon Sep 17 00:00:00 2001
From: Roo Code
Date: Tue, 12 Aug 2025 09:02:03 +0000
Subject: [PATCH] fix: use max_completion_tokens for GPT-5 models in LiteLLM provider

- GPT-5 models require max_completion_tokens instead of the deprecated max_tokens parameter
- Added detection for GPT-5 model variants (gpt-5, gpt5, GPT-5, etc.)
- Updated both createMessage and completePrompt methods to handle GPT-5 models
- Added comprehensive tests for GPT-5 model handling

Fixes #6979
---
 src/api/providers/__tests__/lite-llm.spec.ts | 178 +++++++++++++++++--
 src/api/providers/lite-llm.ts                |  23 ++-
 2 files changed, 187 insertions(+), 14 deletions(-)

diff --git a/src/api/providers/__tests__/lite-llm.spec.ts b/src/api/providers/__tests__/lite-llm.spec.ts
index 26ebbc35258..0056619e46c 100644
--- a/src/api/providers/__tests__/lite-llm.spec.ts
+++ b/src/api/providers/__tests__/lite-llm.spec.ts
@@ -10,15 +10,9 @@ import { litellmDefaultModelId, litellmDefaultModelInfo } from "@roo-code/types"
 
 vi.mock("vscode", () => ({}))
 
 // Mock OpenAI
-vi.mock("openai", () => {
-	const mockStream = {
-		[Symbol.asyncIterator]: vi.fn(),
-	}
-
-	const mockCreate = vi.fn().mockReturnValue({
-		withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
-	})
+const mockCreate = vi.fn()
+vi.mock("openai", () => {
 	return {
 		default: vi.fn().mockImplementation(() => ({
 			chat: {
@@ -35,6 +29,15 @@ vi.mock("../fetchers/modelCache", () => ({
 	getModels: vi.fn().mockImplementation(() => {
 		return Promise.resolve({
 			[litellmDefaultModelId]: litellmDefaultModelInfo,
+			"gpt-5": { ...litellmDefaultModelInfo, maxTokens: 8192 },
+			gpt5: { ...litellmDefaultModelInfo, maxTokens: 8192 },
+			"GPT-5": { ...litellmDefaultModelInfo, maxTokens: 8192 },
+			"gpt-5-turbo": { ...litellmDefaultModelInfo, maxTokens: 8192 },
+			"gpt5-preview": { ...litellmDefaultModelInfo, maxTokens: 8192 },
+			"gpt-4": { ...litellmDefaultModelInfo, maxTokens: 8192 },
+			"claude-3-opus": { ...litellmDefaultModelInfo, maxTokens: 8192 },
+			"llama-3": { ...litellmDefaultModelInfo, maxTokens: 8192 },
+			"gpt-4-turbo": { ...litellmDefaultModelInfo, maxTokens: 8192 },
 		})
 	}),
 }))
@@ -42,7 +45,6 @@ vi.mock("../fetchers/modelCache", () => ({
 describe("LiteLLMHandler", () => {
 	let handler: LiteLLMHandler
 	let mockOptions: ApiHandlerOptions
-	let mockOpenAIClient: any
 
 	beforeEach(() => {
 		vi.clearAllMocks()
@@ -52,7 +54,6 @@ describe("LiteLLMHandler", () => {
 			litellmModelId: litellmDefaultModelId,
 		}
 		handler = new LiteLLMHandler(mockOptions)
-		mockOpenAIClient = new OpenAI()
 	})
 
 	describe("prompt caching", () => {
@@ -85,7 +86,7 @@ describe("LiteLLMHandler", () => {
 				},
 			}
 
-			mockOpenAIClient.chat.completions.create.mockReturnValue({
+			mockCreate.mockReturnValue({
 				withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
 			})
 
@@ -96,7 +97,7 @@ describe("LiteLLMHandler", () => {
 			}
 
 			// Verify that create was called with cache control headers
-			const createCall = mockOpenAIClient.chat.completions.create.mock.calls[0][0]
+			const createCall = mockCreate.mock.calls[0][0]
 
 			// Check system message has cache control in the proper format
 			expect(createCall.messages[0]).toMatchObject({
@@ -155,4 +156,157 @@ describe("LiteLLMHandler", () => {
 			})
 		})
 	})
+
+	describe("GPT-5 model handling", () => {
+		it("should use max_completion_tokens instead of max_tokens for GPT-5 models", async () => {
+			const optionsWithGPT5: ApiHandlerOptions = {
+				...mockOptions,
+				litellmModelId: "gpt-5",
+			}
+			handler = new LiteLLMHandler(optionsWithGPT5)
+
+			const systemPrompt = "You are a helpful assistant"
+			const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Hello" }]
+
+			// Mock the stream response
+			const mockStream = {
+				async *[Symbol.asyncIterator]() {
+					yield {
+						choices: [{ delta: { content: "Hello!" } }],
+						usage: {
+							prompt_tokens: 10,
+							completion_tokens: 5,
+						},
+					}
+				},
+			}
+
+			mockCreate.mockReturnValue({
+				withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
+			})
+
+			const generator = handler.createMessage(systemPrompt, messages)
+			const results = []
+			for await (const chunk of generator) {
+				results.push(chunk)
+			}
+
+			// Verify that create was called with max_completion_tokens instead of max_tokens
+			const createCall = mockCreate.mock.calls[0][0]
+
+			// Should have max_completion_tokens, not max_tokens
+			expect(createCall.max_completion_tokens).toBeDefined()
+			expect(createCall.max_tokens).toBeUndefined()
+		})
+
+		it("should use max_completion_tokens for various GPT-5 model variations", async () => {
+			const gpt5Variations = ["gpt-5", "gpt5", "GPT-5", "gpt-5-turbo", "gpt5-preview"]
+
+			for (const modelId of gpt5Variations) {
+				vi.clearAllMocks()
+
+				const optionsWithGPT5: ApiHandlerOptions = {
+					...mockOptions,
+					litellmModelId: modelId,
+				}
+				handler = new LiteLLMHandler(optionsWithGPT5)
+
+				const systemPrompt = "You are a helpful assistant"
+				const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test" }]
+
+				// Mock the stream response
+				const mockStream = {
+					async *[Symbol.asyncIterator]() {
+						yield {
+							choices: [{ delta: { content: "Response" } }],
+							usage: {
+								prompt_tokens: 10,
+								completion_tokens: 5,
+							},
+						}
+					},
+				}
+
+				mockCreate.mockReturnValue({
+					withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
+				})
+
+				const generator = handler.createMessage(systemPrompt, messages)
+				for await (const chunk of generator) {
+					// Consume the generator
+				}
+
+				// Verify that create was called with max_completion_tokens for this model variation
+				const createCall = mockCreate.mock.calls[0][0]
+
+				expect(createCall.max_completion_tokens).toBeDefined()
+				expect(createCall.max_tokens).toBeUndefined()
+			}
+		})
+
+		it("should still use max_tokens for non-GPT-5 models", async () => {
+			const nonGPT5Models = ["gpt-4", "claude-3-opus", "llama-3", "gpt-4-turbo"]
+
+			for (const modelId of nonGPT5Models) {
+				vi.clearAllMocks()
+
+				const options: ApiHandlerOptions = {
+					...mockOptions,
+					litellmModelId: modelId,
+				}
+				handler = new LiteLLMHandler(options)
+
+				const systemPrompt = "You are a helpful assistant"
+				const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test" }]
+
+				// Mock the stream response
+				const mockStream = {
+					async *[Symbol.asyncIterator]() {
+						yield {
+							choices: [{ delta: { content: "Response" } }],
+							usage: {
+								prompt_tokens: 10,
+								completion_tokens: 5,
+							},
+						}
+					},
+				}
+
+				mockCreate.mockReturnValue({
+					withResponse: vi.fn().mockResolvedValue({ data: mockStream }),
+				})
+
+				const generator = handler.createMessage(systemPrompt, messages)
+				for await (const chunk of generator) {
+					// Consume the generator
+				}
+
+				// Verify that create was called with max_tokens for non-GPT-5 models
+				const createCall = mockCreate.mock.calls[0][0]
+
+				expect(createCall.max_tokens).toBeDefined()
+				expect(createCall.max_completion_tokens).toBeUndefined()
+			}
+		})
+
+		it("should use max_completion_tokens in completePrompt for GPT-5 models", async () => {
+			const optionsWithGPT5: ApiHandlerOptions = {
+				...mockOptions,
+				litellmModelId: "gpt-5",
+			}
+			handler = new LiteLLMHandler(optionsWithGPT5)
+
+			mockCreate.mockResolvedValue({
+				choices: [{ message: { content: "Test response" } }],
+			})
+
+			await handler.completePrompt("Test prompt")
+
+			// Verify that create was called with max_completion_tokens
+			const createCall = mockCreate.mock.calls[0][0]
+
+			expect(createCall.max_completion_tokens).toBeDefined()
+			expect(createCall.max_tokens).toBeUndefined()
+		})
+	})
 })
diff --git a/src/api/providers/lite-llm.ts b/src/api/providers/lite-llm.ts
index 7cea7411feb..a26e22cbfbc 100644
--- a/src/api/providers/lite-llm.ts
+++ b/src/api/providers/lite-llm.ts
@@ -107,9 +107,11 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 		// Required by some providers; others default to max tokens allowed
 		let maxTokens: number | undefined = info.maxTokens ?? undefined
 
+		// Check if this is a GPT-5 model that requires max_completion_tokens instead of max_tokens
+		const isGPT5Model = modelId.toLowerCase().includes("gpt-5") || modelId.toLowerCase().includes("gpt5")
+
 		const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model: modelId,
-			max_tokens: maxTokens,
 			messages: [systemMessage, ...enhancedMessages],
 			stream: true,
 			stream_options: {
@@ -117,6 +119,14 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 			},
 		}
 
+		// GPT-5 models require max_completion_tokens instead of the deprecated max_tokens parameter
+		if (isGPT5Model && maxTokens) {
+			// @ts-ignore - max_completion_tokens is not in the OpenAI types yet but is supported
+			requestOptions.max_completion_tokens = maxTokens
+		} else if (maxTokens) {
+			requestOptions.max_tokens = maxTokens
+		}
+
 		if (this.supportsTemperature(modelId)) {
 			requestOptions.temperature = this.options.modelTemperature ?? 0
 		}
@@ -179,6 +189,9 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 	async completePrompt(prompt: string): Promise<string> {
 		const { id: modelId, info } = await this.fetchModel()
 
+		// Check if this is a GPT-5 model that requires max_completion_tokens instead of max_tokens
+		const isGPT5Model = modelId.toLowerCase().includes("gpt-5") || modelId.toLowerCase().includes("gpt5")
+
 		try {
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
 				model: modelId,
@@ -189,7 +202,13 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa
 				requestOptions.temperature = this.options.modelTemperature ?? 0
 			}
 
-			requestOptions.max_tokens = info.maxTokens
+			// GPT-5 models require max_completion_tokens instead of the deprecated max_tokens parameter
+			if (isGPT5Model && info.maxTokens) {
+				// @ts-ignore - max_completion_tokens is not in the OpenAI types yet but is supported
+				requestOptions.max_completion_tokens = info.maxTokens
+			} else if (info.maxTokens) {
+				requestOptions.max_tokens = info.maxTokens
+			}
 
 			const response = await this.client.chat.completions.create(requestOptions)
 			return response.choices[0]?.message.content || ""
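Reviewer note: the sketch below is a minimal, self-contained restatement of the parameter-selection logic this patch adds, so the intended behavior can be checked outside the diff. It is not code from the repository: `isGpt5ModelId`, `applyMaxTokenLimit`, and the simplified `SketchRequestOptions` shape are illustrative names only; the real provider applies the same branching directly to the OpenAI SDK request objects shown in the diff above.

```typescript
// Illustrative sketch only; names and the request shape are not from the codebase.
interface SketchRequestOptions {
	model: string
	max_tokens?: number
	max_completion_tokens?: number
}

// Same detection rule as the patch: case-insensitive substring match on
// "gpt-5" or "gpt5" anywhere in the model id.
function isGpt5ModelId(modelId: string): boolean {
	const id = modelId.toLowerCase()
	return id.includes("gpt-5") || id.includes("gpt5")
}

// Same branching as the patch: GPT-5 ids get max_completion_tokens, other ids
// keep the legacy max_tokens field, and a missing limit sets neither field.
function applyMaxTokenLimit(options: SketchRequestOptions, maxTokens?: number): SketchRequestOptions {
	if (!maxTokens) {
		return options
	}
	return isGpt5ModelId(options.model)
		? { ...options, max_completion_tokens: maxTokens }
		: { ...options, max_tokens: maxTokens }
}

// GPT-5 variants get max_completion_tokens, everything else keeps max_tokens.
console.log(applyMaxTokenLimit({ model: "gpt-5-turbo" }, 8192)) // { model: "gpt-5-turbo", max_completion_tokens: 8192 }
console.log(applyMaxTokenLimit({ model: "gpt-4-turbo" }, 8192)) // { model: "gpt-4-turbo", max_tokens: 8192 }
```

Because the check is a plain substring match, the variants exercised in the new tests ("gpt-5", "gpt5", "GPT-5", "gpt-5-turbo", "gpt5-preview") all take the max_completion_tokens branch, while ids like "gpt-4", "gpt-4-turbo", "claude-3-opus", and "llama-3" keep max_tokens.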