Skip to content

Commit 7beb678

Browse files
authored
Merge pull request #1391 from krissetto/better-dmr-support
Better DMR support
2 parents de1b4d3 + 1909cbc commit 7beb678

File tree

7 files changed

+668
-98
lines changed

7 files changed

+668
-98
lines changed

docs/PROVIDERS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,6 @@ The DMR provider supports speculative decoding for faster inference. Configure i
134134
- `speculative_num_tokens` (int): Number of tokens to generate speculatively
135135
- `speculative_acceptance_rate` (float): Acceptance rate threshold for speculative tokens
136136

137-
All three options are passed to `docker model configure` as command-line flags.
137+
All three options are sent to Model Runner via its internal `POST /engines/_configure` API endpoint.
138138

139139
You can also pass any flag of the underlying model runtime (llama.cpp or vllm) using the `runtime_flags` option

docs/USAGE.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -722,12 +722,12 @@ models:
722722
speculative_acceptance_rate: 0.8 # Acceptance rate threshold
723723
```
724724

725-
All three speculative decoding options are passed to `docker model configure` as flags:
726-
- `speculative_draft_model` → `--speculative-draft-model`
727-
- `speculative_num_tokens` → `--speculative-num-tokens`
728-
- `speculative_acceptance_rate` → `--speculative-acceptance-rate`
725+
All three speculative decoding options are sent to Model Runner via its internal `POST /engines/_configure` API endpoint:
726+
- `speculative_draft_model` → `speculative.draft_model`
727+
- `speculative_num_tokens` → `speculative.num_tokens`
728+
- `speculative_acceptance_rate` → `speculative.min_acceptance_rate`
729729

730-
These options work alongside `max_tokens` (which sets `--context-size`) and `runtime_flags`.
730+
These options work alongside `max_tokens` (which sets `context-size`) and `runtime_flags`.
731731

732732
##### Troubleshooting:
733733

pkg/fake/proxy.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,14 @@ func RemoveHeadersHook(i *cassette.Interaction) error {
206206
return nil
207207
}
208208

209-
// DefaultMatcher creates a matcher that normalizes tool call IDs for consistent matching.
209+
// DefaultMatcher creates a matcher that normalizes dynamic fields for consistent matching.
210210
// The onError callback is called if reading the request body fails (nil logs and returns false).
211211
func DefaultMatcher(onError func(err error)) recorder.MatcherFunc {
212+
// Normalize tool call IDs (they change between requests)
212213
callIDRegex := regexp.MustCompile(`call_[a-z0-9\-]+`)
214+
// Normalize max_tokens/max_output_tokens/maxOutputTokens field (varies based on models.dev
215+
// cache state and provider cloning behavior). Handles both snake_case and camelCase variants.
216+
maxTokensRegex := regexp.MustCompile(`"(?:max_(?:output_)?tokens|maxOutputTokens)":\d+,?`)
213217

214218
return func(r *http.Request, i cassette.Request) bool {
215219
if r.Body == nil || r.Body == http.NoBody {
@@ -234,8 +238,13 @@ func DefaultMatcher(onError func(err error)) recorder.MatcherFunc {
234238
r.Body.Close()
235239
r.Body = io.NopCloser(bytes.NewBuffer(reqBody))
236240

237-
// Normalize tool call IDs for matching
238-
return callIDRegex.ReplaceAllString(string(reqBody), "call_ID") == callIDRegex.ReplaceAllString(i.Body, "call_ID")
241+
// Normalize dynamic fields for matching
242+
normalizedReq := callIDRegex.ReplaceAllString(string(reqBody), "call_ID")
243+
normalizedReq = maxTokensRegex.ReplaceAllString(normalizedReq, "")
244+
normalizedCassette := callIDRegex.ReplaceAllString(i.Body, "call_ID")
245+
normalizedCassette = maxTokensRegex.ReplaceAllString(normalizedCassette, "")
246+
247+
return normalizedReq == normalizedCassette
239248
}
240249
}
241250

pkg/model/provider/clone.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ func CloneWithOptions(ctx context.Context, base Provider, opts ...options.Opt) P
1919

2020
// Apply max_tokens override if present in options
2121
// We need to apply it to the ModelConfig itself since that's what providers use
22+
// Only update MaxTokens if an option explicitly sets it (non-zero value)
2223
modelConfig := config.ModelConfig
2324
for _, opt := range mergedOpts {
2425
tempOpts := &options.ModelOptions{}
2526
opt(tempOpts)
26-
mt := tempOpts.MaxTokens()
27-
modelConfig.MaxTokens = &mt
27+
if mt := tempOpts.MaxTokens(); mt != 0 {
28+
modelConfig.MaxTokens = &mt
29+
}
2830
}
2931

3032
// Use NewWithModels to support cloning routers that reference other models.

pkg/model/provider/clone_test.go

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,85 @@ func TestCloneWithOptions_DirectProvider(t *testing.T) {
137137
assert.Nil(t, clonedConfig.ModelConfig.ThinkingBudget,
138138
"ThinkingBudget should be nil after cloning with WithThinking(false)")
139139
}
140+
141+
func TestCloneWithOptions_PreservesMaxTokens(t *testing.T) {
142+
t.Parallel()
143+
144+
// This test verifies that max_tokens is preserved when cloning a provider
145+
// with options that don't explicitly set max_tokens. Previously, options
146+
// that didn't set max_tokens would accidentally clear it to 0.
147+
148+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
149+
w.Header().Set("Content-Type", "text/event-stream")
150+
_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n"))
151+
_, _ = w.Write([]byte("data: [DONE]\n\n"))
152+
}))
153+
defer server.Close()
154+
155+
maxTokens := int64(8192)
156+
cfg := &latest.ModelConfig{
157+
Provider: "openai",
158+
Model: "gpt-4o",
159+
BaseURL: server.URL,
160+
MaxTokens: &maxTokens,
161+
}
162+
163+
env := newCloneTestEnv(map[string]string{
164+
"OPENAI_API_KEY": "test-key",
165+
})
166+
167+
provider, err := New(t.Context(), cfg, env, options.WithMaxTokens(maxTokens))
168+
require.NoError(t, err)
169+
170+
// Clone with an option that doesn't affect max_tokens (e.g., WithThinking)
171+
cloned := CloneWithOptions(t.Context(), provider, options.WithThinking(false))
172+
173+
clonedConfig := cloned.BaseConfig()
174+
175+
// MaxTokens should be preserved, not cleared to 0 or nil
176+
require.NotNil(t, clonedConfig.ModelConfig.MaxTokens,
177+
"MaxTokens should be preserved after cloning with unrelated options")
178+
assert.Equal(t, maxTokens, *clonedConfig.ModelConfig.MaxTokens,
179+
"MaxTokens value should be unchanged after cloning")
180+
}
181+
182+
func TestCloneWithOptions_OverridesMaxTokens(t *testing.T) {
183+
t.Parallel()
184+
185+
// This test verifies that max_tokens can be explicitly overridden when cloning.
186+
187+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
188+
w.Header().Set("Content-Type", "text/event-stream")
189+
_, _ = w.Write([]byte("data: {\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\n"))
190+
_, _ = w.Write([]byte("data: [DONE]\n\n"))
191+
}))
192+
defer server.Close()
193+
194+
originalMaxTokens := int64(8192)
195+
newMaxTokens := int64(4096)
196+
197+
cfg := &latest.ModelConfig{
198+
Provider: "openai",
199+
Model: "gpt-4o",
200+
BaseURL: server.URL,
201+
MaxTokens: &originalMaxTokens,
202+
}
203+
204+
env := newCloneTestEnv(map[string]string{
205+
"OPENAI_API_KEY": "test-key",
206+
})
207+
208+
provider, err := New(t.Context(), cfg, env, options.WithMaxTokens(originalMaxTokens))
209+
require.NoError(t, err)
210+
211+
// Clone with an explicit max_tokens override
212+
cloned := CloneWithOptions(t.Context(), provider, options.WithMaxTokens(newMaxTokens))
213+
214+
clonedConfig := cloned.BaseConfig()
215+
216+
// MaxTokens should be updated to the new value
217+
require.NotNil(t, clonedConfig.ModelConfig.MaxTokens,
218+
"MaxTokens should not be nil after cloning with explicit override")
219+
assert.Equal(t, newMaxTokens, *clonedConfig.ModelConfig.MaxTokens,
220+
"MaxTokens should be updated to the new value")
221+
}

0 commit comments

Comments (0)