Skip to content

Commit 0e9c2d9

Browse files
committed
fix(openai): ensure non-compatible providers send correct max_tokens field
DeepSeek and Mistral are not fully OpenAI-compatible but delegated to CompatibleProvider.Completion(), which unconditionally mapped MaxTokens to max_completion_tokens on the wire. Both APIs expect max_tokens. Add ChatCompletionRequestTransform hook to CompatibleConfig (Strategy pattern) that lets providers adjust the SDK request after convertParams builds it. DeepSeek and Mistral supply transforms that swap max_completion_tokens back to max_tokens and clear unsupported fields. Move Mistral's user/reasoning_effort stripping from preprocessParams (CompletionParams level) to transformRequest (SDK request level) where it correctly prevents the fields from being serialized. Add FakeCompletionServer test helper and wire-level tests that capture actual JSON request bodies to assert correct field names on the wire.
1 parent abb7730 commit 0e9c2d9

File tree

7 files changed

+262
-63
lines changed

7 files changed

+262
-63
lines changed

internal/testutil/fakeserver.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Package testutil provides testing utilities and fixtures for any-llm.
2+
package testutil
3+
4+
import (
	"encoding/json"
	"io"
	"net/http"
	"net/http/httptest"
	"sync"
	"testing"
)
11+
12+
// FakeCompletionServer creates an httptest server that captures the raw JSON
13+
// request body and returns a minimal valid OpenAI-compatible chat completion
14+
// response. The captured body is returned so callers can assert on the exact
15+
// JSON field names sent over the wire.
16+
func FakeCompletionServer(t *testing.T) (serverURL string, capturedBody func() map[string]any) {
17+
t.Helper()
18+
19+
var body map[string]any
20+
21+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
22+
raw, err := io.ReadAll(r.Body)
23+
if err != nil {
24+
t.Errorf("reading request body: %v", err)
25+
http.Error(w, "bad request", http.StatusBadRequest)
26+
return
27+
}
28+
29+
if err := json.Unmarshal(raw, &body); err != nil {
30+
t.Errorf("unmarshalling request body: %v", err)
31+
http.Error(w, "bad request", http.StatusBadRequest)
32+
return
33+
}
34+
35+
w.Header().Set("Content-Type", "application/json")
36+
// Minimal valid chat completion response.
37+
_, _ = w.Write([]byte(`{
38+
"id": "chatcmpl-test",
39+
"object": "chat.completion",
40+
"created": 1700000000,
41+
"model": "test-model",
42+
"choices": [{
43+
"index": 0,
44+
"message": {"role": "assistant", "content": "hello"},
45+
"finish_reason": "stop"
46+
}],
47+
"usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8}
48+
}`))
49+
}))
50+
51+
t.Cleanup(srv.Close)
52+
53+
return srv.URL, func() map[string]any { return body }
54+
}

providers/deepseek/deepseek.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ import (
88
"fmt"
99
"slices"
1010

11+
oaisdk "github.com/openai/openai-go"
12+
"github.com/openai/openai-go/packages/param"
13+
1114
"github.com/mozilla-ai/any-llm-go/config"
1215
"github.com/mozilla-ai/any-llm-go/providers"
1316
"github.com/mozilla-ai/any-llm-go/providers/openai"
@@ -50,13 +53,14 @@ type Provider struct {
5053
// New creates a new DeepSeek provider.
5154
func New(opts ...config.Option) (*Provider, error) {
5255
base, err := openai.NewCompatible(openai.CompatibleConfig{
53-
APIKeyEnvVar: envAPIKey,
54-
BaseURLEnvVar: "",
55-
Capabilities: capabilities(),
56-
DefaultAPIKey: "",
57-
DefaultBaseURL: defaultBaseURL,
58-
Name: providerName,
59-
RequireAPIKey: true,
56+
APIKeyEnvVar: envAPIKey,
57+
BaseURLEnvVar: "",
58+
Capabilities: capabilities(),
59+
DefaultAPIKey: "",
60+
DefaultBaseURL: defaultBaseURL,
61+
Name: providerName,
62+
RequireAPIKey: true,
63+
ChatCompletionRequestTransform: transformRequest,
6064
}, opts...)
6165
if err != nil {
6266
return nil, err
@@ -156,6 +160,19 @@ func preprocessParams(params providers.CompletionParams) providers.CompletionPar
156160
}
157161
}
158162

163+
// transformRequest adjusts the OpenAI SDK request for DeepSeek's API.
164+
// DeepSeek uses max_tokens, not max_completion_tokens.
165+
// See: https://api-docs.deepseek.com/api/create-chat-completion
166+
func transformRequest(req *oaisdk.ChatCompletionNewParams) {
167+
if req.MaxCompletionTokens.Valid() {
168+
// Set max_tokens using max_completion_tokens value.
169+
req.MaxTokens = oaisdk.Int(req.MaxCompletionTokens.Value)
170+
}
171+
172+
// Clear unsupported fields from the request.
173+
req.MaxCompletionTokens = param.Opt[int64]{}
174+
}
175+
159176
// preprocessMessagesForJSONSchema injects the JSON schema into the last user message.
160177
// Returns the modified messages and true if injection succeeded, or the original messages
161178
// and false if injection failed (no user message, non-string content, or marshal error).

providers/deepseek/deepseek_test.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,38 @@ func TestPreprocessMessagesForJSONSchema(t *testing.T) {
354354
})
355355
}
356356

357+
func TestCompletionSendsMaxTokensOnWire(t *testing.T) {
358+
t.Parallel()
359+
360+
serverURL, capturedBody := testutil.FakeCompletionServer(t)
361+
362+
provider, err := New(
363+
config.WithAPIKey("test-key"),
364+
config.WithBaseURL(serverURL),
365+
)
366+
require.NoError(t, err)
367+
368+
maxTokens := 512
369+
params := providers.CompletionParams{
370+
Model: "deepseek-chat",
371+
Messages: testutil.SimpleMessages(),
372+
MaxTokens: &maxTokens,
373+
}
374+
375+
_, err = provider.Completion(context.Background(), params)
376+
require.NoError(t, err)
377+
378+
body := capturedBody()
379+
380+
// DeepSeek is not fully OpenAI-compatible.
381+
// The wire request must use max_tokens (not max_completion_tokens)
382+
// because that is what the DeepSeek API accepts.
383+
// See: https://api-docs.deepseek.com/api/create-chat-completion
384+
require.Contains(t, body, "max_tokens")
385+
require.NotContains(t, body, "max_completion_tokens")
386+
require.Equal(t, float64(512), body["max_tokens"])
387+
}
388+
357389
// Integration tests - only run if DeepSeek API key is available.
358390

359391
func TestIntegrationCompletion(t *testing.T) {

providers/mistral/mistral.go

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import (
66
"context"
77
"slices"
88

9+
oaisdk "github.com/openai/openai-go"
10+
"github.com/openai/openai-go/packages/param"
11+
912
"github.com/mozilla-ai/any-llm-go/config"
1013
"github.com/mozilla-ai/any-llm-go/providers"
1114
"github.com/mozilla-ai/any-llm-go/providers/openai"
@@ -48,13 +51,14 @@ type Provider struct {
4851
// New creates a new Mistral provider.
4952
func New(opts ...config.Option) (*Provider, error) {
5053
base, err := openai.NewCompatible(openai.CompatibleConfig{
51-
APIKeyEnvVar: envAPIKey,
52-
BaseURLEnvVar: "",
53-
Capabilities: capabilities(),
54-
DefaultAPIKey: "",
55-
DefaultBaseURL: defaultBaseURL,
56-
Name: providerName,
57-
RequireAPIKey: true,
54+
APIKeyEnvVar: envAPIKey,
55+
BaseURLEnvVar: "",
56+
Capabilities: capabilities(),
57+
DefaultAPIKey: "",
58+
DefaultBaseURL: defaultBaseURL,
59+
Name: providerName,
60+
RequireAPIKey: true,
61+
ChatCompletionRequestTransform: transformRequest,
5862
}, opts...)
5963
if err != nil {
6064
return nil, err
@@ -131,11 +135,23 @@ func patchMessages(messages []providers.Message) []providers.Message {
131135
}
132136

133137
// preprocessParams handles Mistral's API requirements.
134-
// Mistral doesn't accept the "user" or "reasoning_effort" fields and requires
135-
// an assistant message between tool results and user messages.
138+
// Mistral requires an assistant message between tool results and user messages.
136139
func preprocessParams(params providers.CompletionParams) providers.CompletionParams {
137140
params.Messages = patchMessages(slices.Clone(params.Messages))
138-
params.ReasoningEffort = "" // Mistral doesn't support reasoning_effort; Magistral models reason automatically.
139-
params.User = "" // Mistral doesn't support the user field.
140141
return params
141142
}
143+
144+
// transformRequest adjusts the OpenAI SDK request for Mistral's API.
145+
// Mistral uses max_tokens (not max_completion_tokens) and does not accept user or reasoning_effort fields.
146+
// See: https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post
147+
func transformRequest(req *oaisdk.ChatCompletionNewParams) {
148+
if req.MaxCompletionTokens.Valid() {
149+
// Set max_tokens using max_completion_tokens value.
150+
req.MaxTokens = oaisdk.Int(req.MaxCompletionTokens.Value)
151+
}
152+
153+
// Clear unsupported fields from the request.
154+
req.MaxCompletionTokens = param.Opt[int64]{}
155+
req.User = param.Opt[string]{}
156+
req.ReasoningEffort = ""
157+
}

providers/mistral/mistral_test.go

Lines changed: 84 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -80,51 +80,6 @@ func TestProviderName(t *testing.T) {
8080
func TestPreprocessParams(t *testing.T) {
8181
t.Parallel()
8282

83-
t.Run("strips user field from params", func(t *testing.T) {
84-
t.Parallel()
85-
86-
params := providers.CompletionParams{
87-
Model: "mistral-small-latest",
88-
Messages: testutil.SimpleMessages(),
89-
User: "test-user",
90-
}
91-
92-
result := preprocessParams(params)
93-
94-
require.Equal(t, params.Model, result.Model)
95-
require.Empty(t, result.User)
96-
})
97-
98-
t.Run("strips reasoning effort from params", func(t *testing.T) {
99-
t.Parallel()
100-
101-
params := providers.CompletionParams{
102-
Model: "magistral-small-latest",
103-
Messages: testutil.SimpleMessages(),
104-
ReasoningEffort: providers.ReasoningEffortLow,
105-
}
106-
107-
result := preprocessParams(params)
108-
109-
require.Equal(t, params.Model, result.Model)
110-
require.Empty(t, result.ReasoningEffort)
111-
})
112-
113-
t.Run("passes through params without user field", func(t *testing.T) {
114-
t.Parallel()
115-
116-
params := providers.CompletionParams{
117-
Model: "mistral-small-latest",
118-
Messages: testutil.SimpleMessages(),
119-
}
120-
121-
result := preprocessParams(params)
122-
123-
require.Equal(t, params.Model, result.Model)
124-
require.Equal(t, len(params.Messages), len(result.Messages))
125-
require.Empty(t, result.User)
126-
})
127-
12883
t.Run("patches messages with tool-to-user sequence", func(t *testing.T) {
12984
t.Parallel()
13085

@@ -276,6 +231,90 @@ func TestPatchMessages(t *testing.T) {
276231
})
277232
}
278233

234+
func TestCompletionSendsMaxTokensOnWire(t *testing.T) {
235+
t.Parallel()
236+
237+
serverURL, capturedBody := testutil.FakeCompletionServer(t)
238+
239+
provider, err := New(
240+
config.WithAPIKey("test-key"),
241+
config.WithBaseURL(serverURL),
242+
)
243+
require.NoError(t, err)
244+
245+
maxTokens := 256
246+
params := providers.CompletionParams{
247+
Model: "mistral-small-latest",
248+
Messages: testutil.SimpleMessages(),
249+
MaxTokens: &maxTokens,
250+
}
251+
252+
_, err = provider.Completion(context.Background(), params)
253+
require.NoError(t, err)
254+
255+
body := capturedBody()
256+
257+
// Mistral is not fully OpenAI-compatible.
258+
// The wire request must use max_tokens (not max_completion_tokens)
259+
// because that is what the Mistral API accepts.
260+
// See: https://docs.mistral.ai/api?property=operation-chat_completion_v1_chat_completions_post_request_max_tokens
261+
require.Contains(t, body, "max_tokens")
262+
require.NotContains(t, body, "max_completion_tokens")
263+
require.Equal(t, float64(256), body["max_tokens"])
264+
}
265+
266+
func TestCompletionStripsUserField(t *testing.T) {
267+
t.Parallel()
268+
269+
serverURL, capturedBody := testutil.FakeCompletionServer(t)
270+
271+
provider, err := New(
272+
config.WithAPIKey("test-key"),
273+
config.WithBaseURL(serverURL),
274+
)
275+
require.NoError(t, err)
276+
277+
params := providers.CompletionParams{
278+
Model: "mistral-small-latest",
279+
Messages: testutil.SimpleMessages(),
280+
User: "test-user",
281+
}
282+
283+
_, err = provider.Completion(context.Background(), params)
284+
require.NoError(t, err)
285+
286+
body := capturedBody()
287+
288+
// Mistral doesn't support the user field; it must not appear on the wire.
289+
require.NotContains(t, body, "user")
290+
}
291+
292+
func TestCompletionStripsReasoningEffort(t *testing.T) {
293+
t.Parallel()
294+
295+
serverURL, capturedBody := testutil.FakeCompletionServer(t)
296+
297+
provider, err := New(
298+
config.WithAPIKey("test-key"),
299+
config.WithBaseURL(serverURL),
300+
)
301+
require.NoError(t, err)
302+
303+
params := providers.CompletionParams{
304+
Model: "magistral-small-latest",
305+
Messages: testutil.SimpleMessages(),
306+
ReasoningEffort: providers.ReasoningEffortHigh,
307+
}
308+
309+
_, err = provider.Completion(context.Background(), params)
310+
require.NoError(t, err)
311+
312+
body := capturedBody()
313+
314+
// Mistral doesn't support reasoning_effort; it must not appear on the wire.
315+
require.NotContains(t, body, "reasoning_effort")
316+
}
317+
279318
// Integration tests - only run if Mistral API key is available.
280319

281320
func TestIntegrationCompletion(t *testing.T) {

providers/openai/compatible.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ type CompatibleConfig struct {
7070

7171
// RequireAPIKey indicates whether an API key is required.
7272
RequireAPIKey bool
73+
74+
// ChatCompletionRequestTransform is an optional function that modifies the chat completion
75+
// request after construction. Providers that are not fully OpenAI-compatible use this to
76+
// adjust wire-level fields (e.g. swapping max_completion_tokens back to max_tokens).
77+
// Nil means no transformation.
78+
ChatCompletionRequestTransform func(*openai.ChatCompletionNewParams)
7379
}
7480

7581
// Ensure CompatibleProvider implements the required interfaces.
@@ -143,6 +149,9 @@ func (p *CompatibleProvider) Completion(
143149
}
144150

145151
req := convertParams(params)
152+
if p.compatibleConfig.ChatCompletionRequestTransform != nil {
153+
p.compatibleConfig.ChatCompletionRequestTransform(&req)
154+
}
146155

147156
resp, err := p.client.Chat.Completions.New(ctx, req)
148157
if err != nil {
@@ -170,6 +179,9 @@ func (p *CompatibleProvider) CompletionStream(
170179
}
171180

172181
req := convertParams(params)
182+
if p.compatibleConfig.ChatCompletionRequestTransform != nil {
183+
p.compatibleConfig.ChatCompletionRequestTransform(&req)
184+
}
173185
stream := p.client.Chat.Completions.NewStreaming(ctx, req)
174186

175187
for stream.Next() {

0 commit comments

Comments
 (0)