Skip to content

Commit 4b59f4f

Browse files
committed
fix(openai): ensure non-compatible providers send correct max_tokens field
DeepSeek and Mistral are not fully OpenAI-compatible but delegated to CompatibleProvider.Completion(), which unconditionally mapped MaxTokens to max_completion_tokens on the wire. Both APIs expect max_tokens. Add ChatCompletionRequestTransform hook to CompatibleConfig (Strategy pattern) that lets providers adjust the SDK request after convertParams builds it. DeepSeek and Mistral supply transforms that swap max_completion_tokens back to max_tokens and clear unsupported fields. Move Mistral's user/reasoning_effort stripping from preprocessParams (CompletionParams level) to transformRequest (SDK request level) where it correctly prevents the fields from being serialized. Add FakeCompletionServer test helper and wire-level tests that capture actual JSON request bodies to assert correct field names on the wire.
1 parent abb7730 commit 4b59f4f

File tree

7 files changed

+437
-70
lines changed

7 files changed

+437
-70
lines changed

internal/testutil/fakeserver.go

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
// Package testutil provides testing utilities and fixtures for any-llm.
2+
package testutil
3+
4+
import (
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/http/httptest"
10+
"sync"
11+
"testing"
12+
)
13+
14+
// FakeCompletionServer creates an httptest server that captures the raw JSON
15+
// request body and returns a minimal valid OpenAI-compatible chat completion
16+
// response. The captured body is returned so callers can assert on the exact
17+
// JSON field names sent over the wire.
18+
func FakeCompletionServer(t *testing.T) (serverURL string, capturedBody func() map[string]any) {
19+
t.Helper()
20+
21+
var (
22+
mu sync.Mutex
23+
body map[string]any
24+
)
25+
26+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
27+
raw, err := io.ReadAll(r.Body)
28+
if err != nil {
29+
t.Errorf("reading request body: %v", err)
30+
http.Error(w, "bad request", http.StatusBadRequest)
31+
return
32+
}
33+
34+
mu.Lock()
35+
if err := json.Unmarshal(raw, &body); err != nil {
36+
mu.Unlock()
37+
t.Errorf("unmarshalling request body: %v", err)
38+
http.Error(w, "bad request", http.StatusBadRequest)
39+
return
40+
}
41+
mu.Unlock()
42+
43+
w.Header().Set("Content-Type", "application/json")
44+
// Minimal valid chat completion response.
45+
_, _ = w.Write([]byte(`{
46+
"id": "chatcmpl-test",
47+
"object": "chat.completion",
48+
"created": 1700000000,
49+
"model": "test-model",
50+
"choices": [{
51+
"index": 0,
52+
"message": {"role": "assistant", "content": "hello"},
53+
"finish_reason": "stop"
54+
}],
55+
"usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8}
56+
}`))
57+
}))
58+
59+
t.Cleanup(srv.Close)
60+
61+
return srv.URL, func() map[string]any {
62+
mu.Lock()
63+
defer mu.Unlock()
64+
return body
65+
}
66+
}
67+
68+
// FakeStreamingServer creates an httptest server that captures the raw JSON
69+
// request body and returns a minimal valid OpenAI-compatible streaming (SSE)
70+
// response. The captured body is returned so callers can assert on the exact
71+
// JSON field names sent over the wire.
72+
func FakeStreamingServer(t *testing.T) (serverURL string, capturedBody func() map[string]any) {
73+
t.Helper()
74+
75+
var (
76+
mu sync.Mutex
77+
body map[string]any
78+
)
79+
80+
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
81+
raw, err := io.ReadAll(r.Body)
82+
if err != nil {
83+
t.Errorf("reading request body: %v", err)
84+
http.Error(w, "bad request", http.StatusBadRequest)
85+
return
86+
}
87+
88+
mu.Lock()
89+
if err := json.Unmarshal(raw, &body); err != nil {
90+
mu.Unlock()
91+
t.Errorf("unmarshalling request body: %v", err)
92+
http.Error(w, "bad request", http.StatusBadRequest)
93+
return
94+
}
95+
mu.Unlock()
96+
97+
w.Header().Set("Content-Type", "text/event-stream")
98+
w.Header().Set("Cache-Control", "no-cache")
99+
w.Header().Set("Connection", "keep-alive")
100+
101+
// Minimal valid SSE streaming response.
102+
chunk := `{"id":"chatcmpl-test","object":"chat.completion.chunk","created":1700000000,"model":"test-model","choices":[{"index":0,"delta":{"role":"assistant","content":"hello"},"finish_reason":null}]}`
103+
done := `{"id":"chatcmpl-test","object":"chat.completion.chunk","created":1700000000,"model":"test-model","choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}`
104+
105+
_, _ = fmt.Fprintf(w, "data: %s\n\n", chunk)
106+
_, _ = fmt.Fprintf(w, "data: %s\n\n", done)
107+
_, _ = fmt.Fprint(w, "data: [DONE]\n\n")
108+
}))
109+
110+
t.Cleanup(srv.Close)
111+
112+
return srv.URL, func() map[string]any {
113+
mu.Lock()
114+
defer mu.Unlock()
115+
return body
116+
}
117+
}

providers/deepseek/deepseek.go

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@ import (
88
"fmt"
99
"slices"
1010

11+
oaisdk "github.com/openai/openai-go"
12+
"github.com/openai/openai-go/packages/param"
13+
1114
"github.com/mozilla-ai/any-llm-go/config"
1215
"github.com/mozilla-ai/any-llm-go/providers"
1316
"github.com/mozilla-ai/any-llm-go/providers/openai"
@@ -50,13 +53,14 @@ type Provider struct {
5053
// New creates a new DeepSeek provider.
5154
func New(opts ...config.Option) (*Provider, error) {
5255
base, err := openai.NewCompatible(openai.CompatibleConfig{
53-
APIKeyEnvVar: envAPIKey,
54-
BaseURLEnvVar: "",
55-
Capabilities: capabilities(),
56-
DefaultAPIKey: "",
57-
DefaultBaseURL: defaultBaseURL,
58-
Name: providerName,
59-
RequireAPIKey: true,
56+
APIKeyEnvVar: envAPIKey,
57+
BaseURLEnvVar: "",
58+
Capabilities: capabilities(),
59+
ChatCompletionRequestTransform: transformRequest,
60+
DefaultAPIKey: "",
61+
DefaultBaseURL: defaultBaseURL,
62+
Name: providerName,
63+
RequireAPIKey: true,
6064
}, opts...)
6165
if err != nil {
6266
return nil, err
@@ -156,6 +160,19 @@ func preprocessParams(params providers.CompletionParams) providers.CompletionPar
156160
}
157161
}
158162

163+
// transformRequest adjusts the OpenAI SDK request for DeepSeek's API.
164+
// DeepSeek uses max_tokens, not max_completion_tokens.
165+
// See: https://api-docs.deepseek.com/api/create-chat-completion
166+
func transformRequest(req *oaisdk.ChatCompletionNewParams) {
167+
if req.MaxCompletionTokens.Valid() {
168+
// Set max_tokens using max_completion_tokens value.
169+
req.MaxTokens = oaisdk.Int(req.MaxCompletionTokens.Value)
170+
}
171+
172+
// Clear unsupported fields from the request.
173+
req.MaxCompletionTokens = param.Opt[int64]{}
174+
}
175+
159176
// preprocessMessagesForJSONSchema injects the JSON schema into the last user message.
160177
// Returns the modified messages and true if injection succeeded, or the original messages
161178
// and false if injection failed (no user message, non-string content, or marshal error).

providers/deepseek/deepseek_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,73 @@ func TestPreprocessMessagesForJSONSchema(t *testing.T) {
354354
})
355355
}
356356

357+
func TestCompletionSendsMaxTokensOnWire(t *testing.T) {
358+
t.Parallel()
359+
360+
serverURL, capturedBody := testutil.FakeCompletionServer(t)
361+
362+
provider, err := New(
363+
config.WithAPIKey("test-key"),
364+
config.WithBaseURL(serverURL),
365+
)
366+
require.NoError(t, err)
367+
368+
maxTokens := 512
369+
params := providers.CompletionParams{
370+
Model: "deepseek-chat",
371+
Messages: testutil.SimpleMessages(),
372+
MaxTokens: &maxTokens,
373+
}
374+
375+
_, err = provider.Completion(context.Background(), params)
376+
require.NoError(t, err)
377+
378+
body := capturedBody()
379+
380+
// DeepSeek is not fully OpenAI-compatible.
381+
// The wire request must use max_tokens (not max_completion_tokens)
382+
// because that is what the DeepSeek API accepts.
383+
// See: https://api-docs.deepseek.com/api/create-chat-completion
384+
require.Contains(t, body, "max_tokens")
385+
require.NotContains(t, body, "max_completion_tokens")
386+
require.Equal(t, float64(512), body["max_tokens"])
387+
}
388+
389+
func TestCompletionStreamSendsMaxTokensOnWire(t *testing.T) {
390+
t.Parallel()
391+
392+
serverURL, capturedBody := testutil.FakeStreamingServer(t)
393+
394+
provider, err := New(
395+
config.WithAPIKey("test-key"),
396+
config.WithBaseURL(serverURL),
397+
)
398+
require.NoError(t, err)
399+
400+
maxTokens := 512
401+
params := providers.CompletionParams{
402+
Model: "deepseek-chat",
403+
Messages: testutil.SimpleMessages(),
404+
MaxTokens: &maxTokens,
405+
Stream: true,
406+
}
407+
408+
chunks, errs := provider.CompletionStream(context.Background(), params)
409+
for range chunks {
410+
// Drain the channel.
411+
}
412+
require.NoError(t, <-errs)
413+
414+
body := capturedBody()
415+
416+
// DeepSeek is not fully OpenAI-compatible.
417+
// The streaming wire request must also use max_tokens (not max_completion_tokens).
418+
// See: https://api-docs.deepseek.com/api/create-chat-completion
419+
require.Contains(t, body, "max_tokens")
420+
require.NotContains(t, body, "max_completion_tokens")
421+
require.Equal(t, float64(512), body["max_tokens"])
422+
}
423+
357424
// Integration tests - only run if DeepSeek API key is available.
358425

359426
func TestIntegrationCompletion(t *testing.T) {

providers/mistral/mistral.go

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ import (
66
"context"
77
"slices"
88

9+
oaisdk "github.com/openai/openai-go"
10+
"github.com/openai/openai-go/packages/param"
11+
912
"github.com/mozilla-ai/any-llm-go/config"
1013
"github.com/mozilla-ai/any-llm-go/providers"
1114
"github.com/mozilla-ai/any-llm-go/providers/openai"
@@ -48,13 +51,14 @@ type Provider struct {
4851
// New creates a new Mistral provider.
4952
func New(opts ...config.Option) (*Provider, error) {
5053
base, err := openai.NewCompatible(openai.CompatibleConfig{
51-
APIKeyEnvVar: envAPIKey,
52-
BaseURLEnvVar: "",
53-
Capabilities: capabilities(),
54-
DefaultAPIKey: "",
55-
DefaultBaseURL: defaultBaseURL,
56-
Name: providerName,
57-
RequireAPIKey: true,
54+
APIKeyEnvVar: envAPIKey,
55+
BaseURLEnvVar: "",
56+
Capabilities: capabilities(),
57+
ChatCompletionRequestTransform: transformRequest,
58+
DefaultAPIKey: "",
59+
DefaultBaseURL: defaultBaseURL,
60+
Name: providerName,
61+
RequireAPIKey: true,
5862
}, opts...)
5963
if err != nil {
6064
return nil, err
@@ -69,7 +73,7 @@ func (p *Provider) Completion(
6973
ctx context.Context,
7074
params providers.CompletionParams,
7175
) (*providers.ChatCompletion, error) {
72-
params = preprocessParams(params)
76+
params = patchMessageParams(params)
7377
return p.CompatibleProvider.Completion(ctx, params)
7478
}
7579

@@ -79,7 +83,7 @@ func (p *Provider) CompletionStream(
7983
ctx context.Context,
8084
params providers.CompletionParams,
8185
) (<-chan providers.ChatCompletionChunk, <-chan error) {
82-
params = preprocessParams(params)
86+
params = patchMessageParams(params)
8387
return p.CompatibleProvider.CompletionStream(ctx, params)
8488
}
8589

@@ -130,12 +134,24 @@ func patchMessages(messages []providers.Message) []providers.Message {
130134
return result
131135
}
132136

133-
// preprocessParams handles Mistral's API requirements.
134-
// Mistral doesn't accept the "user" or "reasoning_effort" fields and requires
135-
// an assistant message between tool results and user messages.
136-
func preprocessParams(params providers.CompletionParams) providers.CompletionParams {
137+
// patchMessageParams handles Mistral's message-level requirements.
138+
// Mistral requires an assistant message between tool results and user messages.
139+
func patchMessageParams(params providers.CompletionParams) providers.CompletionParams {
137140
params.Messages = patchMessages(slices.Clone(params.Messages))
138-
params.ReasoningEffort = "" // Mistral doesn't support reasoning_effort; Magistral models reason automatically.
139-
params.User = "" // Mistral doesn't support the user field.
140141
return params
141142
}
143+
144+
// transformRequest adjusts the OpenAI SDK request for Mistral's API.
145+
// Mistral uses max_tokens (not max_completion_tokens) and does not accept user or reasoning_effort fields.
146+
// See: https://docs.mistral.ai/api/#tag/chat/operation/chat_completion_v1_chat_completions_post
147+
func transformRequest(req *oaisdk.ChatCompletionNewParams) {
148+
if req.MaxCompletionTokens.Valid() {
149+
// Set max_tokens using max_completion_tokens value.
150+
req.MaxTokens = oaisdk.Int(req.MaxCompletionTokens.Value)
151+
}
152+
153+
// Clear unsupported fields from the request.
154+
req.MaxCompletionTokens = param.Opt[int64]{}
155+
req.User = param.Opt[string]{}
156+
req.ReasoningEffort = ""
157+
}

0 commit comments

Comments
 (0)