Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions cmd/cli/commands/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func newUpCommand() *cobra.Command {
return err
}

if ctxSize > 0 {
if cmd.Flags().Changed("context-size") {
sendInfo(fmt.Sprintf("Setting context size to %d", ctxSize))
}

Expand All @@ -82,12 +82,18 @@ func newUpCommand() *cobra.Command {
}

for _, model := range models {
configuration := inference.BackendConfiguration{
Speculative: speculativeConfig,
}
if cmd.Flags().Changed("context-size") {
// TODO is the context size the same for all models?
v := int32(ctxSize)
configuration.ContextSize = &v
}

if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
Model: model,
BackendConfiguration: inference.BackendConfiguration{
ContextSize: ctxSize,
Speculative: speculativeConfig,
},
Model: model,
BackendConfiguration: configuration,
}); err != nil {
configErrFmtString := "failed to configure backend for model %s with context-size %d"
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, err)
Expand Down
50 changes: 42 additions & 8 deletions cmd/cli/commands/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package commands
import (
"encoding/json"
"fmt"
"strconv"

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/pkg/inference"
Expand All @@ -11,13 +12,45 @@ import (
"github.com/spf13/cobra"
)

// Int32PtrValue implements the pflag.Value interface for *int32 pointers.
// It allows a flag to default to nil instead of 0, so an unset flag can be
// told apart from an explicitly provided value.
type Int32PtrValue struct {
	// ptr is the address of the caller's *int32; Set stores a pointer to the
	// parsed value there.
	ptr **int32
}

// NewInt32PtrValue returns an Int32PtrValue that writes parsed values
// through p. The *int32 that p refers to is not modified until Set succeeds.
func NewInt32PtrValue(p **int32) *Int32PtrValue {
	return &Int32PtrValue{ptr: p}
}

// String returns the current value in base 10, or the empty string when no
// destination is attached or no value has been set yet.
func (v *Int32PtrValue) String() string {
	if v == nil || v.ptr == nil || *v.ptr == nil {
		return ""
	}
	return strconv.FormatInt(int64(**v.ptr), 10)
}

// Set parses s as a base-10 signed 32-bit integer and stores a pointer to the
// result in the destination. It returns the strconv error when s is not a
// valid int32, and an error (rather than panicking) when the value has no
// destination, e.g. a zero-value Int32PtrValue. The destination is left
// unchanged on failure.
func (v *Int32PtrValue) Set(s string) error {
	val, err := strconv.ParseInt(s, 10, 32)
	if err != nil {
		return err
	}
	// Guard the dereference the same way String does; without this a
	// zero-value Int32PtrValue would panic here.
	if v == nil || v.ptr == nil {
		return fmt.Errorf("int32 pointer flag has no destination")
	}
	i32 := int32(val)
	*v.ptr = &i32
	return nil
}

// Type returns the type name reported for this flag value.
func (v *Int32PtrValue) Type() string {
	return "int32"
}

func newConfigureCmd() *cobra.Command {
var opts scheduling.ConfigureRequest
var draftModel string
var numTokens int
var minAcceptanceRate float64
var hfOverrides string
var reasoningBudget int64
var contextSize *int32
var reasoningBudget *int32

c := &cobra.Command{
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--reasoning-budget=<n>] MODEL",
Expand All @@ -34,6 +67,8 @@ func newConfigureCmd() *cobra.Command {
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
// contextSize is nil by default, only set if user provided the flag
opts.ContextSize = contextSize
// Build the speculative config if any speculative flags are set
if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 {
opts.Speculative = &inference.SpeculativeDecodingConfig{
Expand All @@ -57,25 +92,24 @@ func newConfigureCmd() *cobra.Command {
}
opts.VLLM.HFOverrides = hfo
}
// Set llama.cpp-specific reasoning budget if explicitly provided
// Note: We check if flag was changed rather than checking value > 0
// because 0 is a valid value (disables reasoning) and -1 means unlimited
if cmd.Flags().Changed("reasoning-budget") {
// Set llama.cpp-specific reasoning budget if provided
// reasoningBudget is nil by default, only set if user provided the flag
if reasoningBudget != nil {
if opts.LlamaCpp == nil {
opts.LlamaCpp = &inference.LlamaCppConfig{}
}
opts.LlamaCpp.ReasoningBudget = &reasoningBudget
opts.LlamaCpp.ReasoningBudget = reasoningBudget
}
return desktopClient.ConfigureBackend(opts)
},
ValidArgsFunction: completion.ModelNames(getDesktopClient, -1),
}

c.Flags().Int64Var(&opts.ContextSize, "context-size", -1, "context size (in tokens)")
c.Flags().Var(NewInt32PtrValue(&contextSize), "context-size", "context size (in tokens)")
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
c.Flags().StringVar(&hfOverrides, "hf_overrides", "", "HuggingFace model config overrides (JSON) - vLLM only")
c.Flags().Int64Var(&reasoningBudget, "reasoning-budget", 0, "reasoning budget for reasoning models - llama.cpp only")
c.Flags().Var(NewInt32PtrValue(&reasoningBudget), "reasoning-budget", "reasoning budget for reasoning models - llama.cpp only")
return c
}
48 changes: 21 additions & 27 deletions cmd/cli/commands/configure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@ func TestConfigureCmdReasoningBudgetFlag(t *testing.T) {
t.Fatal("--reasoning-budget flag not found")
}

// Verify the default value is 0
if reasoningBudgetFlag.DefValue != "0" {
t.Errorf("Expected default reasoning-budget value to be '0', got '%s'", reasoningBudgetFlag.DefValue)
// Verify the default value is empty (nil pointer)
if reasoningBudgetFlag.DefValue != "" {
t.Errorf("Expected default reasoning-budget value to be '' (nil), got '%s'", reasoningBudgetFlag.DefValue)
}

// Verify the flag type
if reasoningBudgetFlag.Value.Type() != "int64" {
t.Errorf("Expected reasoning-budget flag type to be 'int64', got '%s'", reasoningBudgetFlag.Value.Type())
if reasoningBudgetFlag.Value.Type() != "int32" {
t.Errorf("Expected reasoning-budget flag type to be 'int32', got '%s'", reasoningBudgetFlag.Value.Type())
}
}

Expand All @@ -30,31 +30,31 @@ func TestConfigureCmdReasoningBudgetFlagChanged(t *testing.T) {
name string
setValue string
expectChanged bool
expectedValue int64
expectedValue string
}{
{
name: "flag not set - should not be changed",
setValue: "",
expectChanged: false,
expectedValue: 0,
expectedValue: "",
},
{
name: "flag set to 0 (disable reasoning) - should be changed",
setValue: "0",
expectChanged: true,
expectedValue: 0,
expectedValue: "0",
},
{
name: "flag set to -1 (unlimited) - should be changed",
setValue: "-1",
expectChanged: true,
expectedValue: -1,
expectedValue: "-1",
},
{
name: "flag set to positive value - should be changed",
setValue: "1024",
expectChanged: true,
expectedValue: 1024,
expectedValue: "1024",
},
}

Expand All @@ -77,13 +77,11 @@ func TestConfigureCmdReasoningBudgetFlagChanged(t *testing.T) {
t.Errorf("Expected Changed() = %v, got %v", tt.expectChanged, isChanged)
}

// Verify the value
value, err := cmd.Flags().GetInt64("reasoning-budget")
if err != nil {
t.Fatalf("Failed to get reasoning-budget flag value: %v", err)
}
// Verify the value using String() method
flag := cmd.Flags().Lookup("reasoning-budget")
value := flag.Value.String()
if value != tt.expectedValue {
t.Errorf("Expected value = %d, got %d", tt.expectedValue, value)
t.Errorf("Expected value = %s, got %s", tt.expectedValue, value)
}
})
}
Expand Down Expand Up @@ -120,9 +118,9 @@ func TestConfigureCmdContextSizeFlag(t *testing.T) {
t.Fatal("--context-size flag not found")
}

// Verify the default value is -1 (indicating not set)
if contextSizeFlag.DefValue != "-1" {
t.Errorf("Expected default context-size value to be '-1', got '%s'", contextSizeFlag.DefValue)
// Verify the default value is empty (nil pointer)
if contextSizeFlag.DefValue != "" {
t.Errorf("Expected default context-size value to be '' (nil), got '%s'", contextSizeFlag.DefValue)
}

// Test setting the flag value
Expand All @@ -131,14 +129,10 @@ func TestConfigureCmdContextSizeFlag(t *testing.T) {
t.Errorf("Failed to set context-size flag: %v", err)
}

// Verify the value was set
contextSizeValue, err := cmd.Flags().GetInt64("context-size")
if err != nil {
t.Errorf("Failed to get context-size flag value: %v", err)
}

if contextSizeValue != 8192 {
t.Errorf("Expected context-size flag value to be 8192, got %d", contextSizeValue)
// Verify the value was set using String() method
contextSizeValue := contextSizeFlag.Value.String()
if contextSizeValue != "8192" {
t.Errorf("Expected context-size flag value to be '8192', got '%s'", contextSizeValue)
}
}

Expand Down
26 changes: 15 additions & 11 deletions cmd/cli/commands/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func verifyModelInspect(t *testing.T, client *desktop.Client, ref, expectedID, e

// createAndPushTestModel creates a minimal test model and pushes it to the local registry.
// Returns the model ID, FQDNs for host and network access, and the manifest digest.
func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextSize uint64) (modelID, hostFQDN, networkFQDN, digest string) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add an integration test path where contextSize is nil to ensure default behavior is preserved

Right now all integration call sites pass a non‑nil *int32 to createAndPushTestModel, so the nil path isn’t exercised. Please add a test that uses contextSize == nil and confirms the model can still be pushed/pulled/inspected and that no unexpected context-size metadata is introduced, preserving the previous default behavior after the switch to *int32.

func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextSize *int32) (modelID, hostFQDN, networkFQDN, digest string) {
ctx := context.Background()

// Use the dummy GGUF file from assets
Expand All @@ -234,8 +234,8 @@ func createAndPushTestModel(t *testing.T, registryURL, modelRef string, contextS
require.NoError(t, err)

// Set context size if specified
if contextSize > 0 {
pkg = pkg.WithContextSize(contextSize)
if contextSize != nil {
pkg = pkg.WithContextSize(*contextSize)
}

// Construct the full reference with the local registry host for pushing from test host
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestIntegration_PullModel(t *testing.T) {
// Create and push two test models with different organizations
// Model 1: custom org (test/test-model:latest)
modelRef1 := "test/test-model:latest"
modelID1, hostFQDN1, networkFQDN1, digest1 := createAndPushTestModel(t, env.registryURL, modelRef1, 2048)
modelID1, hostFQDN1, networkFQDN1, digest1 := createAndPushTestModel(t, env.registryURL, modelRef1, int32ptr(2048))
t.Logf("Test model 1 pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN1, modelID1, networkFQDN1, digest1)

// Generate test cases for custom org model (test/test-model)
Expand All @@ -304,7 +304,7 @@ func TestIntegration_PullModel(t *testing.T) {

// Model 2: default org (ai/test-model:latest)
modelRef2 := "ai/test-model:latest"
modelID2, hostFQDN2, networkFQDN2, digest2 := createAndPushTestModel(t, env.registryURL, modelRef2, 2048)
modelID2, hostFQDN2, networkFQDN2, digest2 := createAndPushTestModel(t, env.registryURL, modelRef2, int32ptr(2048))
t.Logf("Test model 2 pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN2, modelID2, networkFQDN2, digest2)

// Generate test cases for default org model (ai/test-model)
Expand Down Expand Up @@ -420,7 +420,7 @@ func TestIntegration_InspectModel(t *testing.T) {

// Create and push a test model with default org (ai/inspect-test:latest)
modelRef := "ai/inspect-test:latest"
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)

// Pull the model using a short reference
Expand Down Expand Up @@ -479,7 +479,7 @@ func TestIntegration_TagModel(t *testing.T) {

// Create and push a test model with default org (ai/tag-test:latest)
modelRef := "ai/tag-test:latest"
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)

// Pull the model using a simple reference
Expand Down Expand Up @@ -657,7 +657,7 @@ func TestIntegration_PushModel(t *testing.T) {

// Create and push a test model with default org (ai/tag-test:latest)
modelRef := "ai/tag-test:latest"
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)

// Pull the model using a simple reference
Expand Down Expand Up @@ -791,7 +791,7 @@ func TestIntegration_RemoveModel(t *testing.T) {

// Create and push a test model with default org (ai/rm-test:latest)
modelRef := "ai/rm-test:latest"
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, 2048)
modelID, hostFQDN, networkFQDN, digest := createAndPushTestModel(t, env.registryURL, modelRef, int32ptr(2048))
t.Logf("Test model pushed: %s (ID: %s) FQDN: %s Digest: %s", hostFQDN, modelID, networkFQDN, digest)

// Generate all reference test cases
Expand Down Expand Up @@ -842,9 +842,9 @@ func TestIntegration_RemoveModel(t *testing.T) {
t.Run("remove multiple models", func(t *testing.T) {
// Create and push two different models
modelRef1 := "ai/rm-multi-1:latest"
modelID1, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef1, 2048)
modelID1, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef1, int32ptr(2048))
modelRef2 := "ai/rm-multi-2:latest"
modelID2, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef2, 2048)
modelID2, _, _, _ := createAndPushTestModel(t, env.registryURL, modelRef2, int32ptr(2048))

// Pull both models
t.Logf("Pulling first model: rm-multi-1")
Expand Down Expand Up @@ -1014,3 +1014,7 @@ func TestIntegration_RemoveModel(t *testing.T) {
})
})
}

// int32ptr returns a pointer to a newly allocated int32 holding n.
func int32ptr(n int32) *int32 {
	p := new(int32)
	*p = n
	return p
}
4 changes: 2 additions & 2 deletions cmd/cli/commands/package.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error {
distClient := initResult.distClient

// Set context size
if opts.contextSize > 0 {
if cmd.Flags().Changed("context-size") {
cmd.PrintErrf("Setting context size %d\n", opts.contextSize)
pkg = pkg.WithContextSize(opts.contextSize)
pkg = pkg.WithContextSize(int32(opts.contextSize))
}
Comment on lines +287 to 290
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The conversion from uint64 to int32 for contextSize could lead to an integer overflow if a user provides a value larger than math.MaxInt32. While unlikely for a context size, adding a validation check would make the code more robust.

if cmd.Flags().Changed("context-size") {
		if opts.contextSize > 2147483647 { // math.MaxInt32
			return fmt.Errorf("context size %d is too large, must be less than or equal to 2147483647", opts.contextSize)
		}
		cmd.PrintErrf("Setting context size %d\n", opts.contextSize)
		pkg = pkg.WithContextSize(int32(opts.contextSize))
	}


// Add license files
Expand Down
6 changes: 2 additions & 4 deletions cmd/cli/docs/reference/docker_model_configure.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ pname: docker model
plink: docker_model.yaml
options:
- option: context-size
value_type: int64
default_value: "-1"
value_type: int32
description: context size (in tokens)
deprecated: false
hidden: false
Expand All @@ -25,8 +24,7 @@ options:
kubernetes: false
swarm: false
- option: reasoning-budget
value_type: int64
default_value: "0"
value_type: int32
description: reasoning budget for reasoning models - llama.cpp only
deprecated: false
hidden: false
Expand Down
2 changes: 1 addition & 1 deletion cmd/mdltool/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ func cmdPackage(args []string) int {

if contextSize > 0 {
fmt.Println("Setting context size:", contextSize)
b = b.WithContextSize(contextSize)
b = b.WithContextSize(int32(contextSize))
}
Comment on lines 322 to 325
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

As with the same conversion in cmd/cli/commands/package.go, converting contextSize from uint64 to int32 here could overflow. It's best to add a check to ensure the value is within the valid range for an int32.

if contextSize > 0 {
		if contextSize > 2147483647 { // math.MaxInt32
			fmt.Fprintf(os.Stderr, "context size %d is too large, must be less than or equal to 2147483647\n", contextSize)
			return 1
		}
		fmt.Println("Setting context size:", contextSize)
		b = b.WithContextSize(int32(contextSize))
	}


if mmproj != "" {
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (b *Builder) WithLicense(path string) (*Builder, error) {
}, nil
}

func (b *Builder) WithContextSize(size uint64) *Builder {
func (b *Builder) WithContextSize(size int32) *Builder {
return &Builder{
model: mutate.ContextSize(b.model, size),
originalLayers: b.originalLayers,
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/internal/mutate/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type model struct {
base types.ModelArtifact
appended []v1.Layer
configMediaType ggcr.MediaType
contextSize *uint64
contextSize *int32
}

func (m *model) Descriptor() (types.Descriptor, error) {
Expand Down
2 changes: 1 addition & 1 deletion pkg/distribution/internal/mutate/mutate.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func ConfigMediaType(mdl types.ModelArtifact, mt ggcr.MediaType) types.ModelArti
}
}

func ContextSize(mdl types.ModelArtifact, cs uint64) types.ModelArtifact {
func ContextSize(mdl types.ModelArtifact, cs int32) types.ModelArtifact {
return &model{
base: mdl,
contextSize: &cs,
Expand Down
Loading
Loading