Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 7 additions & 107 deletions cmd/cli/commands/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,13 @@ func newComposeCmd() *cobra.Command {
return c
}

// Reasoning budget constants for the think parameter conversion
const (
	reasoningBudgetUnlimited int32 = -1   // no cap on reasoning tokens ("high")
	reasoningBudgetDisabled  int32 = 0    // thinking disabled entirely ("false")
	reasoningBudgetMedium    int32 = 1024 // token budget for --think=medium
	reasoningBudgetLow       int32 = 256  // token budget for --think=low
)

// ptr returns the address of a freshly allocated copy of v, letting
// int32 literals be used where a *int32 is required.
func ptr(v int32) *int32 {
	copied := v
	return &copied
}

func newUpCommand() *cobra.Command {
var models []string
var ctxSize int64
var backend string
var draftModel string
var numTokens int
var minAcceptanceRate float64
var mode string
var think string
c := &cobra.Command{
Use: "up",
RunE: func(cmd *cobra.Command, args []string) error {
Expand All @@ -81,7 +66,7 @@ func newUpCommand() *cobra.Command {
return err
}

if cmd.Flags().Changed("context-size") {
if ctxSize > 0 {
sendInfo(fmt.Sprintf("Setting context size to %d", ctxSize))
}

Expand All @@ -96,52 +81,14 @@ func newUpCommand() *cobra.Command {
sendInfo(fmt.Sprintf("Enabling speculative decoding with draft model: %s", draftModel))
}

// Parse mode if provided
var backendMode *inference.BackendMode
if mode != "" {
parsedMode, err := parseBackendMode(mode)
if err != nil {
_ = sendError(err.Error())
return err
}
backendMode = &parsedMode
sendInfo(fmt.Sprintf("Setting backend mode to %s", mode))
}

// Parse think parameter for reasoning budget
var reasoningBudget *int32
if think != "" {
budget, err := parseThinkToReasoningBudget(think)
if err != nil {
_ = sendError(err.Error())
return err
}
reasoningBudget = budget
sendInfo(fmt.Sprintf("Setting think mode to %s", think))
}

for _, model := range models {
configuration := inference.BackendConfiguration{
Speculative: speculativeConfig,
}
if cmd.Flags().Changed("context-size") {
// TODO is the context size the same for all models?
v := int32(ctxSize)
configuration.ContextSize = &v
}

// Set llama.cpp-specific reasoning budget if provided
if reasoningBudget != nil {
if configuration.LlamaCpp == nil {
configuration.LlamaCpp = &inference.LlamaCppConfig{}
}
configuration.LlamaCpp.ReasoningBudget = reasoningBudget
}

size := int32(ctxSize)
if err := desktopClient.ConfigureBackend(scheduling.ConfigureRequest{
Model: model,
Mode: backendMode,
BackendConfiguration: configuration,
Model: model,
BackendConfiguration: inference.BackendConfiguration{
ContextSize: &size,
Speculative: speculativeConfig,
},
}); err != nil {
configErrFmtString := "failed to configure backend for model %s with context-size %d"
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, err)
Expand Down Expand Up @@ -171,57 +118,10 @@ func newUpCommand() *cobra.Command {
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
c.Flags().StringVar(&mode, "mode", "", "backend operation mode (completion, embedding, reranking)")
c.Flags().StringVar(&think, "think", "", "enable reasoning mode for thinking models (true/false/high/medium/low)")
_ = c.MarkFlagRequired("model")
return c
}

// parseBackendMode maps a case-insensitive mode string to the matching
// inference.BackendMode. Unrecognized values yield BackendModeCompletion
// along with a descriptive error.
func parseBackendMode(mode string) (inference.BackendMode, error) {
	lookup := map[string]inference.BackendMode{
		"completion": inference.BackendModeCompletion,
		"embedding":  inference.BackendModeEmbedding,
		"reranking":  inference.BackendModeReranking,
	}
	if parsed, ok := lookup[strings.ToLower(mode)]; ok {
		return parsed, nil
	}
	return inference.BackendModeCompletion, fmt.Errorf("invalid mode %q: must be one of completion, embedding, reranking", mode)
}

// parseThinkToReasoningBudget converts the think parameter string to a
// reasoning budget value.
// Accepts: "true", "false", "high", "medium", "low" (case-insensitive).
// Returns:
//   - nil for empty string or "true" (use server default, which is unlimited)
//   - -1 for "high" (explicitly set unlimited)
//   - 0 for "false" (disable thinking)
//   - 1024 for "medium"
//   - 256 for "low"
func parseThinkToReasoningBudget(think string) (*int32, error) {
	switch strings.ToLower(think) {
	case "", "true":
		// A nil budget lets the server apply its default (currently unlimited).
		return nil, nil
	case "false":
		return ptr(reasoningBudgetDisabled), nil
	case "low":
		return ptr(reasoningBudgetLow), nil
	case "medium":
		return ptr(reasoningBudgetMedium), nil
	case "high":
		// Explicitly request an unlimited reasoning budget.
		return ptr(reasoningBudgetUnlimited), nil
	}
	return nil, fmt.Errorf("invalid think value %q: must be one of true, false, high, medium, low", think)
}

func newDownCommand() *cobra.Command {
c := &cobra.Command{
Use: "down",
Expand Down
81 changes: 0 additions & 81 deletions cmd/cli/commands/compose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,84 +71,3 @@ func TestParseBackendMode(t *testing.T) {
})
}
}

// TestParseThinkToReasoningBudget exercises every accepted think value,
// case-insensitivity, and rejection of unknown or numeric inputs.
func TestParseThinkToReasoningBudget(t *testing.T) {
	cases := []struct {
		name    string
		think   string
		want    *int32
		wantErr bool
	}{
		{name: "empty string returns nil", think: ""},
		{name: "true returns nil (use server default)", think: "true"},
		{name: "TRUE returns nil (case insensitive)", think: "TRUE"},
		{name: "false disables reasoning", think: "false", want: ptr(reasoningBudgetDisabled)},
		{name: "high explicitly sets unlimited (-1)", think: "high", want: ptr(reasoningBudgetUnlimited)},
		{name: "medium sets 1024 tokens", think: "medium", want: ptr(reasoningBudgetMedium)},
		{name: "low sets 256 tokens", think: "low", want: ptr(reasoningBudgetLow)},
		{name: "invalid value returns error", think: "invalid", wantErr: true},
		{name: "numeric string returns error", think: "1024", wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := parseThinkToReasoningBudget(tc.think)
			if tc.wantErr {
				require.Error(t, err)
				return
			}
			require.NoError(t, err)
			if tc.want == nil {
				assert.Nil(t, got)
				return
			}
			require.NotNil(t, got)
			assert.Equal(t, *tc.want, *got)
		})
	}
}
90 changes: 7 additions & 83 deletions cmd/cli/commands/configure.go
Original file line number Diff line number Diff line change
@@ -1,59 +1,17 @@
package commands

import (
"encoding/json"
"fmt"
"strconv"

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/docker/model-runner/pkg/inference"

"github.com/docker/model-runner/pkg/inference/scheduling"
"github.com/spf13/cobra"
)

// Int32PtrValue adapts a **int32 to the pflag.Value interface so a flag can
// distinguish "not provided" (nil) from an explicitly supplied zero.
type Int32PtrValue struct {
	target **int32
}

// NewInt32PtrValue wraps the given pointer-to-pointer for use as a flag value.
func NewInt32PtrValue(p **int32) *Int32PtrValue {
	return &Int32PtrValue{target: p}
}

// Set parses s as a 32-bit integer and stores the address of the parsed
// value in the target, marking the flag as explicitly provided.
func (v *Int32PtrValue) Set(s string) error {
	parsed, err := strconv.ParseInt(s, 10, 32)
	if err != nil {
		return err
	}
	value := int32(parsed)
	*v.target = &value
	return nil
}

// String renders the current value, or the empty string when unset.
func (v *Int32PtrValue) String() string {
	if v.target == nil || *v.target == nil {
		return ""
	}
	return strconv.FormatInt(int64(**v.target), 10)
}

// Type reports the value's type name for flag help output.
func (v *Int32PtrValue) Type() string {
	return "int32"
}

func newConfigureCmd() *cobra.Command {
var opts scheduling.ConfigureRequest
var draftModel string
var numTokens int
var minAcceptanceRate float64
var hfOverrides string
var contextSize *int32
var reasoningBudget *int32
var flags ConfigureFlags

c := &cobra.Command{
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--reasoning-budget=<n>] MODEL",
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--mode=<mode>] [--think] MODEL",
Short: "Configure runtime options for a model",
Hidden: true,
Args: func(cmd *cobra.Command, args []string) error {
Expand All @@ -63,53 +21,19 @@ func newConfigureCmd() *cobra.Command {
"See 'docker model configure --help' for more information",
len(args), args)
}
opts.Model = args[0]
return nil
},
RunE: func(cmd *cobra.Command, args []string) error {
// contextSize is nil by default, only set if user provided the flag
opts.ContextSize = contextSize
// Build the speculative config if any speculative flags are set
if draftModel != "" || numTokens > 0 || minAcceptanceRate > 0 {
opts.Speculative = &inference.SpeculativeDecodingConfig{
DraftModel: draftModel,
NumTokens: numTokens,
MinAcceptanceRate: minAcceptanceRate,
}
}
// Parse and validate HuggingFace overrides if provided (vLLM-specific)
if hfOverrides != "" {
var hfo inference.HFOverrides
if err := json.Unmarshal([]byte(hfOverrides), &hfo); err != nil {
return fmt.Errorf("invalid --hf_overrides JSON: %w", err)
}
// Validate the overrides to prevent command injection
if err := hfo.Validate(); err != nil {
return err
}
if opts.VLLM == nil {
opts.VLLM = &inference.VLLMConfig{}
}
opts.VLLM.HFOverrides = hfo
}
// Set llama.cpp-specific reasoning budget if provided
// reasoningBudget is nil by default, only set if user provided the flag
if reasoningBudget != nil {
if opts.LlamaCpp == nil {
opts.LlamaCpp = &inference.LlamaCppConfig{}
}
opts.LlamaCpp.ReasoningBudget = reasoningBudget
model := args[0]
opts, err := flags.BuildConfigureRequest(model)
if err != nil {
return err
}
return desktopClient.ConfigureBackend(opts)
},
ValidArgsFunction: completion.ModelNames(getDesktopClient, -1),
}

c.Flags().Var(NewInt32PtrValue(&contextSize), "context-size", "context size (in tokens)")
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
c.Flags().Float64Var(&minAcceptanceRate, "speculative-min-acceptance-rate", 0, "minimum acceptance rate for speculative decoding")
c.Flags().StringVar(&hfOverrides, "hf_overrides", "", "HuggingFace model config overrides (JSON) - vLLM only")
c.Flags().Var(NewInt32PtrValue(&reasoningBudget), "reasoning-budget", "reasoning budget for reasoning models - llama.cpp only")
flags.RegisterFlags(c)
return c
}
Loading
Loading