diff --git a/cmd/cli/commands/compose.go b/cmd/cli/commands/compose.go index 029761a4..5f3169a5 100644 --- a/cmd/cli/commands/compose.go +++ b/cmd/cli/commands/compose.go @@ -34,6 +34,19 @@ func newComposeCmd() *cobra.Command { return c } +// Reasoning budget constants for the think parameter conversion +const ( + reasoningBudgetUnlimited int32 = -1 + reasoningBudgetDisabled int32 = 0 + reasoningBudgetMedium int32 = 1024 + reasoningBudgetLow int32 = 256 +) + +// ptr is a helper function to create a pointer to int32 +func ptr(v int32) *int32 { + return &v +} + func newUpCommand() *cobra.Command { var models []string var ctxSize int64 @@ -41,6 +54,8 @@ func newUpCommand() *cobra.Command { var draftModel string var numTokens int var minAcceptanceRate float64 + var mode string + var think string c := &cobra.Command{ Use: "up", RunE: func(cmd *cobra.Command, args []string) error { @@ -81,6 +96,30 @@ func newUpCommand() *cobra.Command { sendInfo(fmt.Sprintf("Enabling speculative decoding with draft model: %s", draftModel)) } + // Parse mode if provided + var backendMode *inference.BackendMode + if mode != "" { + parsedMode, err := parseBackendMode(mode) + if err != nil { + _ = sendError(err.Error()) + return err + } + backendMode = &parsedMode + sendInfo(fmt.Sprintf("Setting backend mode to %s", mode)) + } + + // Parse think parameter for reasoning budget + var reasoningBudget *int32 + if think != "" { + budget, err := parseThinkToReasoningBudget(think) + if err != nil { + _ = sendError(err.Error()) + return err + } + reasoningBudget = budget + sendInfo(fmt.Sprintf("Setting think mode to %s", think)) + } + for _, model := range models { configuration := inference.BackendConfiguration{ Speculative: speculativeConfig, @@ -91,8 +130,17 @@ func newUpCommand() *cobra.Command { configuration.ContextSize = &v } + // Set llama.cpp-specific reasoning budget if provided + if reasoningBudget != nil { + if configuration.LlamaCpp == nil { + configuration.LlamaCpp = &inference.LlamaCppConfig{} + } + configuration.LlamaCpp.ReasoningBudget = reasoningBudget + } + if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{ Model: model, + Mode: backendMode, BackendConfiguration: configuration, }); err != nil { configErrFmtString := "failed to configure backend for model %s with context-size %d" @@ -123,10 +171,57 @@ func newUpCommand() *cobra.Command { c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding") c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively") c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding") + c.Flags().StringVar(&mode, "mode", "", "backend operation mode (completion, embedding, reranking)") + c.Flags().StringVar(&think, "think", "", "enable reasoning mode for thinking models (true/false/high/medium/low)") _ = c.MarkFlagRequired("model") return c } +// parseBackendMode parses a string mode value into an inference.BackendMode. +func parseBackendMode(mode string) (inference.BackendMode, error) { + switch strings.ToLower(mode) { + case "completion": + return inference.BackendModeCompletion, nil + case "embedding": + return inference.BackendModeEmbedding, nil + case "reranking": + return inference.BackendModeReranking, nil + default: + return inference.BackendModeCompletion, fmt.Errorf("invalid mode %q: must be one of completion, embedding, reranking", mode) + } +} + +// parseThinkToReasoningBudget converts the think parameter string to a reasoning budget value. +// Accepts: "true", "false", "high", "medium", "low" +// Returns: +// - nil for empty string or "true" (use server default, which is unlimited) +// - -1 for "high" (explicitly set unlimited) +// - 0 for "false" (disable thinking) +// - 1024 for "medium" +// - 256 for "low" +func parseThinkToReasoningBudget(think string) (*int32, error) { + if think == "" { + return nil, nil + } + + switch strings.ToLower(think) { + case "true": + // Use nil to let the server use its default (currently unlimited) + return nil, nil + case "high": + // Explicitly set unlimited reasoning budget + return ptr(reasoningBudgetUnlimited), nil + case "false": + return ptr(reasoningBudgetDisabled), nil + case "medium": + return ptr(reasoningBudgetMedium), nil + case "low": + return ptr(reasoningBudgetLow), nil + default: + return nil, fmt.Errorf("invalid think value %q: must be one of true, false, high, medium, low", think) + } +} + func newDownCommand() *cobra.Command { c := &cobra.Command{ Use: "down", diff --git a/cmd/cli/commands/compose_test.go b/cmd/cli/commands/compose_test.go new file mode 100644 index 00000000..5d2de01b --- /dev/null +++ b/cmd/cli/commands/compose_test.go @@ -0,0 +1,154 @@ +package commands + +import ( + "testing" + + "github.com/docker/model-runner/pkg/inference" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseBackendMode(t *testing.T) { + tests := []struct { + name string + input string + expected inference.BackendMode + expectError bool + }{ + { + name: "completion mode lowercase", + input: "completion", + expected: inference.BackendModeCompletion, + expectError: false, + }, + { + name: "completion mode uppercase", + input: "COMPLETION", + expected: inference.BackendModeCompletion, + expectError: false, + }, + { + name: "completion mode mixed case", + input: "Completion", + expected: inference.BackendModeCompletion, + expectError: false, + }, + { + name: "embedding mode", + input: "embedding", + expected: inference.BackendModeEmbedding, + expectError: false, + }, + { + name: "reranking mode", + input: "reranking", + expected: inference.BackendModeReranking, + expectError: false, + }, + { + name: "invalid mode", + input: "invalid", + expected: inference.BackendModeCompletion, // default on error + expectError: true, + }, + { + name: "empty string", + input: "", + expected: inference.BackendModeCompletion, // default on error + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseBackendMode(tt.input) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestParseThinkToReasoningBudget(t *testing.T) { + tests := []struct { + name string + input string + expected *int32 + expectError bool + }{ + { + name: "empty string returns nil", + input: "", + expected: nil, + expectError: false, + }, + { + name: "true returns nil (use server default)", + input: "true", + expected: nil, + expectError: false, + }, + { + name: "TRUE returns nil (case insensitive)", + input: "TRUE", + expected: nil, + expectError: false, + }, + { + name: "false disables reasoning", + input: "false", + expected: ptr(reasoningBudgetDisabled), + expectError: false, + }, + { + name: "high explicitly sets unlimited (-1)", + input: "high", + expected: ptr(reasoningBudgetUnlimited), + expectError: false, + }, + { + name: "medium sets 1024 tokens", + input: "medium", + expected: ptr(reasoningBudgetMedium), + expectError: false, + }, + { + name: "low sets 256 tokens", + input: "low", + expected: ptr(reasoningBudgetLow), + expectError: false, + }, + { + name: "invalid value returns error", + input: "invalid", + expected: nil, + expectError: true, + }, + { + name: "numeric string returns error", + input: "1024", + expected: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := parseThinkToReasoningBudget(tt.input) + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + if tt.expected == nil { + assert.Nil(t, result) + } else { + require.NotNil(t, result) + assert.Equal(t, *tt.expected, *result) + } + } + }) + } +} diff --git a/cmd/cli/docs/reference/docker_model_compose_up.yaml b/cmd/cli/docs/reference/docker_model_compose_up.yaml index 9a0bf1b3..70f72e77 100644 --- a/cmd/cli/docs/reference/docker_model_compose_up.yaml +++ b/cmd/cli/docs/reference/docker_model_compose_up.yaml @@ -23,6 +23,15 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: mode + value_type: string + description: backend operation mode (completion, embedding, reranking) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false - option: model value_type: stringArray default_value: '[]' @@ -62,6 +71,16 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: think + value_type: string + description: | + enable reasoning mode for thinking models (true/false/high/medium/low) + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false inherited_options: - option: project-name value_type: string