
Commit 1f205d4

committed
fix: respect includeMaxTokens option in BaseOpenAiCompatibleProvider
- Modified BaseOpenAiCompatibleProvider to only include the max_tokens parameter when the includeMaxTokens option is true
- This fixes issue #6936, where Kimi K2 model output was being truncated to 1024 tokens
- Updated tests for all providers that extend BaseOpenAiCompatibleProvider (groq, fireworks, chutes, sambanova, zai)
- Added new test cases to verify that max_tokens is not included by default and is included when includeMaxTokens is true

Fixes #6936
1 parent 76e5a72 commit 1f205d4
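The practical effect, sketched below with the Groq handler exercised in the updated tests (a minimal illustration, not code from this commit; the commented import path and the placeholder key are assumptions):

// import { GroqHandler } from "../groq" // illustrative path only

// Default after this fix: no max_tokens is sent with the request, so output is no
// longer cut off at the model's default cap (the 1024-token truncation from #6936).
const handler = new GroqHandler({
	apiModelId: "llama-3.1-8b-instant",
	groqApiKey: "my-groq-api-key",
})

// Opting in: max_tokens is added, preferring the user-configured modelMaxTokens and
// falling back to the model's default maxTokens.
const cappedHandler = new GroqHandler({
	apiModelId: "llama-3.1-8b-instant",
	groqApiKey: "my-groq-api-key",
	includeMaxTokens: true,
	modelMaxTokens: 2048,
})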

File tree

6 files changed: +258 -12 lines changed

src/api/providers/__tests__/chutes.spec.ts
src/api/providers/__tests__/fireworks.spec.ts
src/api/providers/__tests__/groq.spec.ts
src/api/providers/__tests__/sambanova.spec.ts
src/api/providers/__tests__/zai.spec.ts
src/api/providers/base-openai-compatible-provider.ts


src/api/providers/__tests__/chutes.spec.ts
Lines changed: 43 additions & 2 deletions

@@ -325,9 +325,8 @@ describe("ChutesHandler", () => {
 		)
 	})
 
-	it("createMessage should pass correct parameters to Chutes client for non-DeepSeek models", async () => {
+	it("createMessage should not include max_tokens by default for non-DeepSeek models", async () => {
 		const modelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
-		const modelInfo = chutesModels[modelId]
 		const handlerWithModel = new ChutesHandler({ apiModelId: modelId, chutesApiKey: "test-chutes-api-key" })
 
 		mockCreate.mockImplementationOnce(() => {
@@ -346,6 +345,48 @@ describe("ChutesHandler", () => {
 		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
 		await messageGenerator.next()
 
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				temperature: 0.5,
+				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+				stream: true,
+				stream_options: { include_usage: true },
+			}),
+		)
+		// Verify max_tokens is NOT included
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.not.objectContaining({
+				max_tokens: expect.anything(),
+			}),
+		)
+	})
+
+	it("createMessage should include max_tokens when includeMaxTokens is true for non-DeepSeek models", async () => {
+		const modelId: ChutesModelId = "unsloth/Llama-3.3-70B-Instruct"
+		const modelInfo = chutesModels[modelId]
+		const handlerWithModel = new ChutesHandler({
+			apiModelId: modelId,
+			chutesApiKey: "test-chutes-api-key",
+			includeMaxTokens: true,
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const systemPrompt = "Test system prompt for Chutes"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Chutes" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
 		expect(mockCreate).toHaveBeenCalledWith(
 			expect.objectContaining({
 				model: modelId,

src/api/providers/__tests__/fireworks.spec.ts
Lines changed: 42 additions & 1 deletion

@@ -324,12 +324,53 @@ describe("FireworksHandler", () => {
 		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
-	it("createMessage should pass correct parameters to Fireworks client", async () => {
+	it("createMessage should not include max_tokens by default", async () => {
+		const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct"
+		const handlerWithModel = new FireworksHandler({
+			apiModelId: modelId,
+			fireworksApiKey: "test-fireworks-api-key",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const systemPrompt = "Test system prompt for Fireworks"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Fireworks" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				temperature: 0.5,
+				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+				stream: true,
+				stream_options: { include_usage: true },
+			}),
+		)
+		// Verify max_tokens is NOT included
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.not.objectContaining({
+				max_tokens: expect.anything(),
+			}),
+		)
+	})
+
+	it("createMessage should include max_tokens when includeMaxTokens is true", async () => {
 		const modelId: FireworksModelId = "accounts/fireworks/models/kimi-k2-instruct"
 		const modelInfo = fireworksModels[modelId]
 		const handlerWithModel = new FireworksHandler({
 			apiModelId: modelId,
 			fireworksApiKey: "test-fireworks-api-key",
+			includeMaxTokens: true,
 		})
 
 		mockCreate.mockImplementationOnce(() => {

src/api/providers/__tests__/groq.spec.ts
Lines changed: 81 additions & 2 deletions

@@ -111,9 +111,8 @@ describe("GroqHandler", () => {
 		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
-	it("createMessage should pass correct parameters to Groq client", async () => {
+	it("createMessage should not include max_tokens by default", async () => {
 		const modelId: GroqModelId = "llama-3.1-8b-instant"
-		const modelInfo = groqModels[modelId]
 		const handlerWithModel = new GroqHandler({ apiModelId: modelId, groqApiKey: "test-groq-api-key" })
 
 		mockCreate.mockImplementationOnce(() => {
@@ -132,6 +131,48 @@ describe("GroqHandler", () => {
 		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
 		await messageGenerator.next()
 
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				temperature: 0.5,
+				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+				stream: true,
+				stream_options: { include_usage: true },
+			}),
+		)
+		// Verify max_tokens is NOT included
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.not.objectContaining({
+				max_tokens: expect.anything(),
+			}),
+		)
+	})
+
+	it("createMessage should include max_tokens when includeMaxTokens is true", async () => {
+		const modelId: GroqModelId = "llama-3.1-8b-instant"
+		const modelInfo = groqModels[modelId]
+		const handlerWithModel = new GroqHandler({
+			apiModelId: modelId,
+			groqApiKey: "test-groq-api-key",
+			includeMaxTokens: true,
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const systemPrompt = "Test system prompt for Groq"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Groq" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
 		expect(mockCreate).toHaveBeenCalledWith(
 			expect.objectContaining({
 				model: modelId,
@@ -143,4 +184,42 @@ describe("GroqHandler", () => {
 			}),
 		)
 	})
+
+	it("createMessage should use modelMaxTokens over default when includeMaxTokens is true", async () => {
+		const modelId: GroqModelId = "llama-3.1-8b-instant"
+		const customMaxTokens = 2048
+		const handlerWithModel = new GroqHandler({
+			apiModelId: modelId,
+			groqApiKey: "test-groq-api-key",
+			includeMaxTokens: true,
+			modelMaxTokens: customMaxTokens,
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const systemPrompt = "Test system prompt for Groq"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Groq" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				max_tokens: customMaxTokens,
+				temperature: 0.5,
+				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+				stream: true,
+				stream_options: { include_usage: true },
+			}),
+		)
+	})
 })

src/api/providers/__tests__/sambanova.spec.ts
Lines changed: 42 additions & 1 deletion

@@ -116,12 +116,53 @@ describe("SambaNovaHandler", () => {
 		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
-	it("createMessage should pass correct parameters to SambaNova client", async () => {
+	it("createMessage should not include max_tokens by default", async () => {
+		const modelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
+		const handlerWithModel = new SambaNovaHandler({
+			apiModelId: modelId,
+			sambaNovaApiKey: "test-sambanova-api-key",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const systemPrompt = "Test system prompt for SambaNova"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for SambaNova" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				temperature: 0.7,
+				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+				stream: true,
+				stream_options: { include_usage: true },
+			}),
+		)
+		// Verify max_tokens is NOT included
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.not.objectContaining({
+				max_tokens: expect.anything(),
+			}),
+		)
+	})
+
+	it("createMessage should include max_tokens when includeMaxTokens is true", async () => {
 		const modelId: SambaNovaModelId = "Meta-Llama-3.3-70B-Instruct"
 		const modelInfo = sambaNovaModels[modelId]
 		const handlerWithModel = new SambaNovaHandler({
 			apiModelId: modelId,
 			sambaNovaApiKey: "test-sambanova-api-key",
+			includeMaxTokens: true,
 		})
 
 		mockCreate.mockImplementationOnce(() => {

src/api/providers/__tests__/zai.spec.ts
Lines changed: 43 additions & 1 deletion

@@ -191,13 +191,55 @@ describe("ZAiHandler", () => {
 		expect(firstChunk.value).toEqual({ type: "usage", inputTokens: 10, outputTokens: 20 })
 	})
 
-	it("createMessage should pass correct parameters to Z AI client", async () => {
+	it("createMessage should not include max_tokens by default", async () => {
+		const modelId: InternationalZAiModelId = "glm-4.5"
+		const handlerWithModel = new ZAiHandler({
+			apiModelId: modelId,
+			zaiApiKey: "test-zai-api-key",
+			zaiApiLine: "international",
+		})
+
+		mockCreate.mockImplementationOnce(() => {
+			return {
+				[Symbol.asyncIterator]: () => ({
+					async next() {
+						return { done: true }
+					},
+				}),
+			}
+		})
+
+		const systemPrompt = "Test system prompt for Z AI"
+		const messages: Anthropic.Messages.MessageParam[] = [{ role: "user", content: "Test message for Z AI" }]
+
+		const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages)
+		await messageGenerator.next()
+
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.objectContaining({
+				model: modelId,
+				temperature: ZAI_DEFAULT_TEMPERATURE,
+				messages: expect.arrayContaining([{ role: "system", content: systemPrompt }]),
+				stream: true,
+				stream_options: { include_usage: true },
+			}),
+		)
+		// Verify max_tokens is NOT included
+		expect(mockCreate).toHaveBeenCalledWith(
+			expect.not.objectContaining({
+				max_tokens: expect.anything(),
+			}),
+		)
+	})
+
+	it("createMessage should include max_tokens when includeMaxTokens is true", async () => {
 		const modelId: InternationalZAiModelId = "glm-4.5"
 		const modelInfo = internationalZAiModels[modelId]
 		const handlerWithModel = new ZAiHandler({
 			apiModelId: modelId,
 			zaiApiKey: "test-zai-api-key",
 			zaiApiLine: "international",
+			includeMaxTokens: true,
 		})
 
 		mockCreate.mockImplementationOnce(() => {

src/api/providers/base-openai-compatible-provider.ts
Lines changed: 7 additions & 5 deletions

@@ -67,22 +67,24 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const {
-			id: model,
-			info: { maxTokens: max_tokens },
-		} = this.getModel()
+		const { id: model, info } = this.getModel()
 
 		const temperature = this.options.modelTemperature ?? this.defaultTemperature
 
 		const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 			model,
-			max_tokens,
 			temperature,
 			messages: [{ role: "system", content: systemPrompt }, ...convertToOpenAiMessages(messages)],
 			stream: true,
 			stream_options: { include_usage: true },
 		}
 
+		// Only add max_tokens if includeMaxTokens is true
+		if (this.options.includeMaxTokens === true) {
+			// Use user-configured modelMaxTokens if available, otherwise fall back to model's default maxTokens
+			params.max_tokens = this.options.modelMaxTokens || info.maxTokens
+		}
+
		const stream = await this.client.chat.completions.create(params)
 
 		for await (const chunk of stream) {
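In isolation, the new gating behaves like the following sketch (same option and field names as the diff above; a standalone helper written for illustration only, not part of this commit):

// Returns the max_tokens value to send, or undefined when the field should be omitted.
function resolveMaxTokens(
	options: { includeMaxTokens?: boolean; modelMaxTokens?: number },
	info: { maxTokens?: number },
): number | undefined {
	if (options.includeMaxTokens !== true) {
		return undefined // max_tokens stays out of the request entirely
	}
	// Prefer the user-configured modelMaxTokens, otherwise the model's default maxTokens
	return options.modelMaxTokens || info.maxTokens
}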
