src/api/providers/__tests__/openai.spec.ts (341 additions, 0 deletions)
@@ -5,6 +5,7 @@ import { OpenAiHandler } from "../openai"
import { ApiHandlerOptions } from "../../../shared/api"
import { Anthropic } from "@anthropic-ai/sdk"
import OpenAI from "openai"
import { openAiModelInfoSaneDefaults } from "@roo-code/types"

const mockCreate = vitest.fn()

@@ -197,6 +198,113 @@ describe("OpenAiHandler", () => {
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.reasoning_effort).toBeUndefined()
})

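// The tests below exercise how includeMaxTokens, the user-configured
// modelMaxTokens, and the model's default maxTokens combine to produce the
// max_completion_tokens request parameter (OpenAI's successor to the
// deprecated max_tokens parameter).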
it("should include max_tokens when includeMaxTokens is true", async () => {
const optionsWithMaxTokens: ApiHandlerOptions = {
...mockOptions,
includeMaxTokens: true,
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096,
supportsPromptCache: false,
},
}
const handlerWithMaxTokens = new OpenAiHandler(optionsWithMaxTokens)
const stream = handlerWithMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called with max_completion_tokens
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_completion_tokens).toBe(4096)
})

it("should not include max_tokens when includeMaxTokens is false", async () => {
const optionsWithoutMaxTokens: ApiHandlerOptions = {
...mockOptions,
includeMaxTokens: false,
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096,
supportsPromptCache: false,
},
}
const handlerWithoutMaxTokens = new OpenAiHandler(optionsWithoutMaxTokens)
const stream = handlerWithoutMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called without max_completion_tokens
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_completion_tokens).toBeUndefined()
})

it("should not include max_tokens when includeMaxTokens is undefined", async () => {
const optionsWithUndefinedMaxTokens: ApiHandlerOptions = {
...mockOptions,
// includeMaxTokens is not set, so max_completion_tokens should be omitted
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096,
supportsPromptCache: false,
},
}
const handlerWithDefaultMaxTokens = new OpenAiHandler(optionsWithUndefinedMaxTokens)
const stream = handlerWithDefaultMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called without max_completion_tokens
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_completion_tokens).toBeUndefined()
})

it("should use user-configured modelMaxTokens instead of model default maxTokens", async () => {
const optionsWithUserMaxTokens: ApiHandlerOptions = {
...mockOptions,
includeMaxTokens: true,
modelMaxTokens: 32000, // User-configured value
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096, // Model's default value (should not be used)
supportsPromptCache: false,
},
}
const handlerWithUserMaxTokens = new OpenAiHandler(optionsWithUserMaxTokens)
const stream = handlerWithUserMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called with user-configured modelMaxTokens (32000), not model default maxTokens (4096)
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_completion_tokens).toBe(32000)
})

it("should fallback to model default maxTokens when user modelMaxTokens is not set", async () => {
const optionsWithoutUserMaxTokens: ApiHandlerOptions = {
...mockOptions,
includeMaxTokens: true,
// modelMaxTokens is not set
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096, // Model's default value (should be used as fallback)
supportsPromptCache: false,
},
}
const handlerWithoutUserMaxTokens = new OpenAiHandler(optionsWithoutUserMaxTokens)
const stream = handlerWithoutUserMaxTokens.createMessage(systemPrompt, messages)
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
// Assert the mockCreate was called with model default maxTokens (4096) as fallback
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_completion_tokens).toBe(4096)
})
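
// Taken together, the five tests above imply resolution logic along these
// lines (a sketch inferred from the expectations, not the actual handler
// implementation):
//
//   if (options.includeMaxTokens) {
//     requestOptions.max_completion_tokens =
//       options.modelMaxTokens ?? options.openAiCustomModelInfo?.maxTokens
//   }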
})

describe("error handling", () => {
@@ -336,6 +444,10 @@ describe("OpenAiHandler", () => {
},
{ path: "/models/chat/completions" },
)

// Verify max_completion_tokens is NOT included when includeMaxTokens is not set
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("max_completion_tokens")
})

it("should handle non-streaming responses with Azure AI Inference Service", async () => {
@@ -378,6 +490,10 @@
},
{ path: "/models/chat/completions" },
)

// Verify max_completion_tokens is NOT included when includeMaxTokens is not set
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("max_completion_tokens")
})

it("should handle completePrompt with Azure AI Inference Service", async () => {
@@ -391,6 +507,10 @@
},
{ path: "/models/chat/completions" },
)

// Verify max_completion_tokens is NOT included when includeMaxTokens is not set
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("max_completion_tokens")
})
})

@@ -433,4 +553,225 @@ describe("OpenAiHandler", () => {
expect(lastCall[0]).not.toHaveProperty("stream_options")
})
})

describe("O3 Family Models", () => {
const o3Options = {
...mockOptions,
openAiModelId: "o3-mini",
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 65536,
supportsPromptCache: false,
reasoningEffort: "medium" as "low" | "medium" | "high",
},
}

it("should handle O3 model with streaming and include max_completion_tokens when includeMaxTokens is true", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
includeMaxTokens: true,
modelMaxTokens: 32000,
modelTemperature: 0.5,
})
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = o3Handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "o3-mini",
messages: [
{
role: "developer",
content: "Formatting re-enabled\nYou are a helpful assistant.",
},
{ role: "user", content: "Hello!" },
],
stream: true,
stream_options: { include_usage: true },
reasoning_effort: "medium",
temperature: 0.5,
// O3 models do not support deprecated max_tokens but do support max_completion_tokens
max_completion_tokens: 32000,
}),
{},
)
})
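
// The message expectations above assume the handler rewrites the system
// prompt for O3-family models roughly like this (a sketch, not necessarily
// the exact implementation):
//
//   const developerMessage = {
//     role: "developer",
//     content: `Formatting re-enabled\n${systemPrompt}`,
//   }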

it("should handle O3 model with streaming and exclude max_tokens when includeMaxTokens is false", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
includeMaxTokens: false,
modelTemperature: 0.7,
})
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = o3Handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "o3-mini",
messages: [
{
role: "developer",
content: "Formatting re-enabled\nYou are a helpful assistant.",
},
{ role: "user", content: "Hello!" },
],
stream: true,
stream_options: { include_usage: true },
reasoning_effort: "medium",
temperature: 0.7,
}),
{},
)

// Verify max_completion_tokens is NOT included
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("max_completion_tokens")
})

it("should handle O3 model non-streaming with reasoning_effort and max_completion_tokens when includeMaxTokens is true", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
openAiStreamingEnabled: false,
includeMaxTokens: true,
modelTemperature: 0.3,
})
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = o3Handler.createMessage(systemPrompt, messages)
const chunks: any[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "o3-mini",
messages: [
{
role: "developer",
content: "Formatting re-enabled\nYou are a helpful assistant.",
},
{ role: "user", content: "Hello!" },
],
reasoning_effort: "medium",
temperature: 0.3,
// O3 models do not support deprecated max_tokens but do support max_completion_tokens
max_completion_tokens: 65536, // Using default maxTokens from o3Options
}),
{},
)

// Verify stream is not set
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("stream")
})

it("should use default temperature of 0 when not specified for O3 models", async () => {
const o3Handler = new OpenAiHandler({
...o3Options,
// No modelTemperature specified
})
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = o3Handler.createMessage(systemPrompt, messages)
await stream.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0, // Default temperature
}),
{},
)
})

it("should handle O3 model with Azure AI Inference Service respecting includeMaxTokens", async () => {
const o3AzureHandler = new OpenAiHandler({
...o3Options,
openAiBaseUrl: "https://test.services.ai.azure.com",
includeMaxTokens: false, // max_completion_tokens should be omitted
})
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = o3AzureHandler.createMessage(systemPrompt, messages)
await stream.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "o3-mini",
}),
{ path: "/models/chat/completions" },
)

// Verify max_completion_tokens is NOT included when includeMaxTokens is false
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("max_completion_tokens")
})
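
// The { path: "/models/chat/completions" } request option above suggests the
// handler detects Azure AI Inference Service base URLs (for example,
// *.services.ai.azure.com) and overrides the request path; this is inferred
// from the tests rather than confirmed against the handler source.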

it("should NOT include max_tokens for O3 model with Azure AI Inference Service even when includeMaxTokens is true", async () => {
const o3AzureHandler = new OpenAiHandler({
...o3Options,
openAiBaseUrl: "https://test.services.ai.azure.com",
includeMaxTokens: true, // Enabled, but the Azure AI Inference path should still omit max_completion_tokens
})
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = o3AzureHandler.createMessage(systemPrompt, messages)
await stream.next()

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: "o3-mini",
// O3 models do not support the deprecated max_tokens parameter
}),
{ path: "/models/chat/completions" },
)

// Verify max_completion_tokens is NOT included even though includeMaxTokens is true
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("max_completion_tokens")
})
})
})