76 changes: 66 additions & 10 deletions src/api/providers/__tests__/openai.spec.ts
@@ -242,10 +242,10 @@ describe("OpenAiHandler", () => {
expect(callArgs.max_completion_tokens).toBeUndefined()
})

it("should not include max_tokens when includeMaxTokens is undefined", async () => {
it("should include max_completion_tokens when includeMaxTokens is undefined (default behavior)", async () => {
const optionsWithUndefinedMaxTokens: ApiHandlerOptions = {
...mockOptions,
- // includeMaxTokens is not set, should not include max_tokens
+ // includeMaxTokens is not set, should default to including max_completion_tokens
openAiCustomModelInfo: {
contextWindow: 128_000,
maxTokens: 4096,
@@ -257,10 +257,10 @@
// Consume the stream to trigger the API call
for await (const _chunk of stream) {
}
- // Assert the mockCreate was called without max_tokens
+ // Assert the mockCreate was called with max_completion_tokens (default behavior)
expect(mockCreate).toHaveBeenCalled()
const callArgs = mockCreate.mock.calls[0][0]
- expect(callArgs.max_completion_tokens).toBeUndefined()
+ expect(callArgs.max_completion_tokens).toBe(4096)
})

it("should use user-configured modelMaxTokens instead of model default maxTokens", async () => {
@@ -306,6 +306,54 @@ describe("OpenAiHandler", () => {
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs.max_completion_tokens).toBe(4096)
})

it("should include max_completion_tokens by default for OpenAI compatible providers", async () => {
const optionsForCompatibleProvider: ApiHandlerOptions = {
...mockOptions,
// includeMaxTokens is not set, simulating OpenAI compatible provider usage
openAiBaseUrl: "https://api.koboldcpp.example.com/v1",
openAiCustomModelInfo: {
contextWindow: 32_000,
maxTokens: 4096,
supportsPromptCache: false,
},
}
const compatibleHandler = new OpenAiHandler(optionsForCompatibleProvider)
const stream = compatibleHandler.createMessage(systemPrompt, messages)

const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// Verify max_completion_tokens is included by default
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).toHaveProperty("max_completion_tokens", 4096)
})

it("should respect includeMaxTokens=false even for OpenAI compatible providers", async () => {
const optionsWithExplicitFalse: ApiHandlerOptions = {
...mockOptions,
includeMaxTokens: false, // Explicitly set to false
openAiBaseUrl: "https://api.koboldcpp.example.com/v1",
openAiCustomModelInfo: {
contextWindow: 32_000,
maxTokens: 4096,
supportsPromptCache: false,
},
}
const handlerWithExplicitFalse = new OpenAiHandler(optionsWithExplicitFalse)
const stream = handlerWithExplicitFalse.createMessage(systemPrompt, messages)

const chunks = []
for await (const chunk of stream) {
chunks.push(chunk)
}

// Verify max_completion_tokens is NOT included when explicitly set to false
const callArgs = mockCreate.mock.calls[0][0]
expect(callArgs).not.toHaveProperty("max_completion_tokens")
})
})

describe("error handling", () => {
@@ -402,6 +450,11 @@ describe("OpenAiHandler", () => {
openAiBaseUrl: "https://test.services.ai.azure.com",
openAiModelId: "deepseek-v3",
azureApiVersion: "2024-05-01-preview",
+ openAiCustomModelInfo: {
+ contextWindow: 128_000,
+ maxTokens: 4096,
+ supportsPromptCache: false,
+ },
}

it("should initialize with Azure AI Inference Service configuration", () => {
@@ -442,13 +495,14 @@
stream: true,
stream_options: { include_usage: true },
temperature: 0,
+ max_completion_tokens: 4096,
},
{ path: "/models/chat/completions" },
)

- // Verify max_tokens is NOT included when includeMaxTokens is not set
+ // Verify max_completion_tokens IS included when includeMaxTokens is not set (default behavior)
const callArgs = mockCreate.mock.calls[0][0]
- expect(callArgs).not.toHaveProperty("max_completion_tokens")
+ expect(callArgs).toHaveProperty("max_completion_tokens")
})

it("should handle non-streaming responses with Azure AI Inference Service", async () => {
@@ -488,13 +542,14 @@
{ role: "user", content: systemPrompt },
{ role: "user", content: "Hello!" },
],
+ max_completion_tokens: 4096,
},
{ path: "/models/chat/completions" },
)

- // Verify max_tokens is NOT included when includeMaxTokens is not set
+ // Verify max_completion_tokens IS included when includeMaxTokens is not set (default behavior)
const callArgs = mockCreate.mock.calls[0][0]
- expect(callArgs).not.toHaveProperty("max_completion_tokens")
+ expect(callArgs).toHaveProperty("max_completion_tokens")
})

it("should handle completePrompt with Azure AI Inference Service", async () => {
@@ -505,13 +560,14 @@
{
model: azureOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
+ max_completion_tokens: 4096,
},
{ path: "/models/chat/completions" },
)

- // Verify max_tokens is NOT included when includeMaxTokens is not set
+ // Verify max_completion_tokens IS included when includeMaxTokens is not set (default behavior)
const callArgs = mockCreate.mock.calls[0][0]
- expect(callArgs).not.toHaveProperty("max_completion_tokens")
+ expect(callArgs).toHaveProperty("max_completion_tokens")
})
})

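The three non-Azure scenarios above differ only in the includeMaxTokens setting, so they could also be expressed as a single parametrized case. A hedged Vitest-style sketch reusing this spec's mockOptions, mockCreate, systemPrompt, and messages fixtures (the consolidation is illustrative, not part of the PR):

it.each([
	{ includeMaxTokens: undefined, expected: 4096 }, // default: cap is sent
	{ includeMaxTokens: true, expected: 4096 }, // explicit opt-in behaves the same
	{ includeMaxTokens: false, expected: undefined }, // only an explicit false suppresses it
])(
	"sends max_completion_tokens=$expected when includeMaxTokens=$includeMaxTokens",
	async ({ includeMaxTokens, expected }) => {
		const handler = new OpenAiHandler({
			...mockOptions,
			includeMaxTokens,
			openAiCustomModelInfo: { contextWindow: 32_000, maxTokens: 4096, supportsPromptCache: false },
		})
		// Consume the stream so the mocked API call is actually issued
		for await (const _chunk of handler.createMessage(systemPrompt, messages)) {
		}
		expect(mockCreate.mock.calls.at(-1)?.[0].max_completion_tokens).toBe(expected)
	},
)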
13 changes: 10 additions & 3 deletions src/api/providers/openai.ts
@@ -401,11 +401,18 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
| OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming,
modelInfo: ModelInfo,
): void {
- // Only add max_completion_tokens if includeMaxTokens is true
- if (this.options.includeMaxTokens === true) {
+ // For OpenAI compatible providers, always include max_completion_tokens to prevent
+ // fallback to provider's default (which may be too small, e.g., koboldcpp's 512 tokens)
+ // Only add max_completion_tokens if includeMaxTokens is explicitly true OR if it's undefined
+ // (treating undefined as true for backward compatibility with OpenAI compatible providers)
+ if (this.options.includeMaxTokens !== false) {
// Use user-configured modelMaxTokens if available, otherwise fall back to model's default maxTokens
// Using max_completion_tokens as max_tokens is deprecated
- requestOptions.max_completion_tokens = this.options.modelMaxTokens || modelInfo.maxTokens
+ const maxTokens = this.options.modelMaxTokens || modelInfo.maxTokens
+ // Only set max_completion_tokens if we have a valid positive value
+ if (maxTokens && maxTokens > 0) {
+ requestOptions.max_completion_tokens = maxTokens
+ }
}
}
}
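Taken in isolation, the new guard reduces to a small pure function. A minimal sketch of the same decision logic (the function name and option shape below are illustrative, not part of the diff):

interface MaxTokensOptions {
	includeMaxTokens?: boolean // undefined is treated as "include"
	modelMaxTokens?: number // user-configured cap; takes precedence over the model default
}

// Mirrors the gating above: only an explicit false opts out, and
// zero or negative values are never sent as max_completion_tokens.
function resolveMaxCompletionTokens(options: MaxTokensOptions, modelDefaultMaxTokens?: number): number | undefined {
	if (options.includeMaxTokens === false) return undefined
	const maxTokens = options.modelMaxTokens || modelDefaultMaxTokens
	return maxTokens && maxTokens > 0 ? maxTokens : undefined
}

resolveMaxCompletionTokens({}, 4096) // 4096: the new default caps the response
resolveMaxCompletionTokens({ modelMaxTokens: 2048 }, 4096) // 2048: the user setting wins
resolveMaxCompletionTokens({ includeMaxTokens: false }, 4096) // undefined: explicit opt-out

Without the cap, an OpenAI-compatible server such as koboldcpp falls back to its own default (512 tokens) and silently truncates long completions, which is the failure mode this change addresses.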