diff --git a/cmd/cli/commands/compose.go b/cmd/cli/commands/compose.go index 7aa3ffac..d71ebef4 100644 --- a/cmd/cli/commands/compose.go +++ b/cmd/cli/commands/compose.go @@ -37,6 +37,7 @@ func newComposeCmd() *cobra.Command { func newUpCommand() *cobra.Command { var models []string var ctxSize int64 + var rawRuntimeFlags string var backend string var draftModel string var numTokens int @@ -69,6 +70,9 @@ func newUpCommand() *cobra.Command { if ctxSize > 0 { sendInfo(fmt.Sprintf("Setting context size to %d", ctxSize)) } + if rawRuntimeFlags != "" { + sendInfo("Setting raw runtime flags to " + rawRuntimeFlags) + } // Build speculative config if any speculative flags are set var speculativeConfig *inference.SpeculativeDecodingConfig @@ -89,10 +93,11 @@ func newUpCommand() *cobra.Command { ContextSize: &size, Speculative: speculativeConfig, }, + RawRuntimeFlags: rawRuntimeFlags, }); err != nil { - configErrFmtString := "failed to configure backend for model %s with context-size %d" - _ = sendErrorf(configErrFmtString+": %v", model, ctxSize, err) - return fmt.Errorf(configErrFmtString+": %w", model, ctxSize, err) + configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s" + _ = sendErrorf(configErrFmtString+": %v", model, ctxSize, rawRuntimeFlags, err) + return fmt.Errorf(configErrFmtString+": %w", model, ctxSize, rawRuntimeFlags, err) } sendInfo("Successfully configured backend for model " + model) } @@ -114,6 +119,7 @@ func newUpCommand() *cobra.Command { } c.Flags().StringArrayVar(&models, "model", nil, "model to use") c.Flags().Int64Var(&ctxSize, "context-size", -1, "context size for the model") + c.Flags().StringVar(&rawRuntimeFlags, "runtime-flags", "", "raw runtime flags to pass to the inference engine") c.Flags().StringVar(&backend, "backend", llamacpp.Name, "inference backend to use") c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding") c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively") diff --git a/cmd/cli/commands/configure.go b/cmd/cli/commands/configure.go index b42910ba..9679e3b0 100644 --- a/cmd/cli/commands/configure.go +++ b/cmd/cli/commands/configure.go @@ -11,15 +11,27 @@ func newConfigureCmd() *cobra.Command { var flags ConfigureFlags c := &cobra.Command{ - Use: "configure [--context-size=] [--speculative-draft-model=] [--hf_overrides=] [--gpu-memory-utilization=] [--mode=] [--think] MODEL", + Use: "configure [--context-size=] [--speculative-draft-model=] [--hf_overrides=] [--gpu-memory-utilization=] [--mode=] [--think] MODEL [-- ]", Short: "Configure runtime options for a model", Hidden: true, Args: func(cmd *cobra.Command, args []string) error { - if len(args) != 1 { - return fmt.Errorf( - "Exactly one model must be specified, got %d: %v\n\n"+ - "See 'docker model configure --help' for more information", - len(args), args) + argsBeforeDash := cmd.ArgsLenAtDash() + if argsBeforeDash == -1 { + // No "--" used, so we need exactly 1 total argument. + if len(args) != 1 { + return fmt.Errorf( + "Exactly one model must be specified, got %d: %v\n\n"+ + "See 'docker model configure --help' for more information", + len(args), args) + } + } else { + // Has "--", so we need exactly 1 argument before it. + if argsBeforeDash != 1 { + return fmt.Errorf( + "Exactly one model must be specified before --, got %d\n\n"+ + "See 'docker model configure --help' for more information", + argsBeforeDash) + } } return nil }, @@ -29,6 +41,7 @@ func newConfigureCmd() *cobra.Command { if err != nil { return err } + opts.RuntimeFlags = args[1:] return desktopClient.ConfigureBackend(opts) }, ValidArgsFunction: completion.ModelNames(getDesktopClient, -1), diff --git a/cmd/cli/commands/configure_test.go b/cmd/cli/commands/configure_test.go index 40a705a1..66180b37 100644 --- a/cmd/cli/commands/configure_test.go +++ b/cmd/cli/commands/configure_test.go @@ -305,3 +305,90 @@ func TestThinkFlagBehavior(t *testing.T) { }) } } + +func TestRuntimeFlagsValidation(t *testing.T) { + tests := []struct { + name string + runtimeFlags []string + expectError bool + errorContains string + }{ + { + name: "valid runtime flags without paths", + runtimeFlags: []string{"--verbose", "--threads", "4"}, + expectError: false, + }, + { + name: "empty runtime flags", + runtimeFlags: []string{}, + expectError: false, + }, + { + name: "reject absolute path in value", + runtimeFlags: []string{"--log-file", "/var/log/model.log"}, + expectError: true, + errorContains: "paths are not allowed", + }, + { + name: "reject absolute path in flag=value format", + runtimeFlags: []string{"--output-file=/tmp/output.txt"}, + expectError: true, + errorContains: "paths are not allowed", + }, + { + name: "reject relative path", + runtimeFlags: []string{"--config", "../config.yaml"}, + expectError: true, + errorContains: "paths are not allowed", + }, + { + name: "reject URL", + runtimeFlags: []string{"--endpoint", "http://example.com/api"}, + expectError: true, + errorContains: "paths are not allowed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + flags := ConfigureFlags{} + req, err := flags.BuildConfigureRequest("test-model") + if err != nil { + t.Fatalf("BuildConfigureRequest failed: %v", err) + } + + // Set runtime flags after building request + req.RuntimeFlags = tt.runtimeFlags + + // Note: The actual validation happens in scheduler.ConfigureRunner, + // but we're testing that the BuildConfigureRequest correctly + // preserves the RuntimeFlags for validation downstream. + // For a true integration test, we would need to mock the scheduler. + + if tt.expectError { + // In this unit test context, we verify the flags are preserved + // The actual validation will happen in the scheduler + if len(req.RuntimeFlags) == 0 && len(tt.runtimeFlags) > 0 { + t.Error("RuntimeFlags should be preserved in the request") + } + } else { + if !equalStringSlices(req.RuntimeFlags, tt.runtimeFlags) { + t.Errorf("Expected RuntimeFlags %v, got %v", tt.runtimeFlags, req.RuntimeFlags) + } + } + }) + } +} + +// equalStringSlices checks if two string slices are equal +func equalStringSlices(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/cmd/cli/docs/reference/docker_model_compose_up.yaml b/cmd/cli/docs/reference/docker_model_compose_up.yaml index 9a0bf1b3..17e91577 100644 --- a/cmd/cli/docs/reference/docker_model_compose_up.yaml +++ b/cmd/cli/docs/reference/docker_model_compose_up.yaml @@ -33,6 +33,15 @@ options: experimentalcli: false kubernetes: false swarm: false + - option: runtime-flags + value_type: string + description: raw runtime flags to pass to the inference engine + deprecated: false + hidden: false + experimental: false + experimentalcli: false + kubernetes: false + swarm: false - option: speculative-draft-model value_type: string description: draft model for speculative decoding diff --git a/cmd/cli/docs/reference/docker_model_configure.yaml b/cmd/cli/docs/reference/docker_model_configure.yaml index 4f941717..9a9d3e8c 100644 --- a/cmd/cli/docs/reference/docker_model_configure.yaml +++ b/cmd/cli/docs/reference/docker_model_configure.yaml @@ -1,7 +1,7 @@ command: docker model configure short: Configure runtime options for a model long: Configure runtime options for a model -usage: docker model configure [--context-size=] [--speculative-draft-model=] [--hf_overrides=] [--gpu-memory-utilization=] [--mode=] [--think] MODEL +usage: docker model configure [--context-size=] [--speculative-draft-model=] [--hf_overrides=] [--gpu-memory-utilization=] [--mode=] [--think] MODEL [-- ] pname: docker model plink: docker_model.yaml options: diff --git a/go.mod b/go.mod index bc41d458..4d7bcd49 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/gpustack/gguf-parser-go v0.22.1 github.com/jaypipes/ghw v0.19.1 github.com/kolesnikovae/go-winjob v1.0.0 + github.com/mattn/go-shellwords v1.0.12 github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/image-spec v1.1.1 github.com/prometheus/client_model v0.6.2 diff --git a/go.sum b/go.sum index 0de37587..a36d73ea 100644 --- a/go.sum +++ b/go.sum @@ -80,6 +80,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk= +github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg= diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go index 4ab1ad79..2ea64d87 100644 --- a/pkg/inference/backend.go +++ b/pkg/inference/backend.go @@ -82,8 +82,9 @@ type LlamaCppConfig struct { type BackendConfiguration struct { // Shared configuration across all backends - ContextSize *int32 `json:"context-size,omitempty"` - Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"` + ContextSize *int32 `json:"context-size,omitempty"` + RuntimeFlags []string `json:"runtime-flags,omitempty"` + Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"` // Backend-specific configuration VLLM *VLLMConfig `json:"vllm,omitempty"` diff --git a/pkg/inference/backends/llamacpp/llamacpp_config.go b/pkg/inference/backends/llamacpp/llamacpp_config.go index 8375eff1..c0fad124 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config.go @@ -79,6 +79,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference args = append(args, "--ctx-size", strconv.FormatInt(int64(*contextSize), 10)) } + // Add arguments from backend config + if config != nil { + args = append(args, config.RuntimeFlags...) + } + // Add arguments for Multimodal projector or jinja (they are mutually exclusive) if path := bundle.MMPROJPath(); path != "" { args = append(args, "--mmproj", path) diff --git a/pkg/inference/backends/llamacpp/llamacpp_config_test.go b/pkg/inference/backends/llamacpp/llamacpp_config_test.go index b67939df..d7452356 100644 --- a/pkg/inference/backends/llamacpp/llamacpp_config_test.go +++ b/pkg/inference/backends/llamacpp/llamacpp_config_test.go @@ -225,6 +225,23 @@ func TestGetArgs(t *testing.T) { "--jinja", ), }, + { + name: "raw flags from backend config", + mode: inference.BackendModeEmbedding, + bundle: &fakeBundle{ + ggufPath: modelPath, + }, + config: &inference.BackendConfiguration{ + RuntimeFlags: []string{"--some", "flag"}, + }, + expected: append(slices.Clone(baseArgs), + "--model", modelPath, + "--host", socket, + "--embeddings", + "--some", "flag", + "--jinja", + ), + }, { name: "multimodal projector removes jinja", mode: inference.BackendModeCompletion, diff --git a/pkg/inference/backends/vllm/vllm_config.go b/pkg/inference/backends/vllm/vllm_config.go index 92bb6fb4..b07c2f2e 100644 --- a/pkg/inference/backends/vllm/vllm_config.go +++ b/pkg/inference/backends/vllm/vllm_config.go @@ -56,7 +56,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference if maxLen := GetMaxModelLen(bundle.RuntimeConfig(), config); maxLen != nil { args = append(args, "--max-model-len", strconv.FormatInt(int64(*maxLen), 10)) } - // If nil, vLLM will automatically derive from the model config + + // Add runtime flags from backend config + if config != nil { + args = append(args, config.RuntimeFlags...) + } // Add vLLM-specific arguments from backend config if config != nil && config.VLLM != nil { diff --git a/pkg/inference/backends/vllm/vllm_config_test.go b/pkg/inference/backends/vllm/vllm_config_test.go index e4538293..ee9304f9 100644 --- a/pkg/inference/backends/vllm/vllm_config_test.go +++ b/pkg/inference/backends/vllm/vllm_config_test.go @@ -83,6 +83,23 @@ func TestGetArgs(t *testing.T) { "8192", }, }, + { + name: "with runtime flags", + bundle: &mockModelBundle{ + safetensorsPath: "/path/to/model", + }, + config: &inference.BackendConfiguration{ + RuntimeFlags: []string{"--gpu-memory-utilization", "0.9"}, + }, + expected: []string{ + "serve", + "/path/to", + "--uds", + "/tmp/socket", + "--gpu-memory-utilization", + "0.9", + }, + }, { name: "with model context size (takes precedence)", bundle: &mockModelBundle{ diff --git a/pkg/inference/runtime_flags.go b/pkg/inference/runtime_flags.go new file mode 100644 index 00000000..d0712c3e --- /dev/null +++ b/pkg/inference/runtime_flags.go @@ -0,0 +1,27 @@ +package inference + +import ( + "fmt" + "strings" +) + +// ValidateRuntimeFlags ensures runtime flags don't contain paths (forward slash "/" or backslash "\") +// to prevent malicious users from overwriting host files via arguments like +// --log-file /some/path, --output-file /etc/passwd, or --log-file C:\Windows\file. +// +// This validation rejects any flag or value containing "/" or "\" to block: +// - Unix/Linux/macOS absolute paths: /var/log/file, /etc/passwd +// - Unix/Linux/macOS relative paths: ../file.txt, ./config +// - Windows absolute paths: C:\Users\file, D:\data\file +// - Windows relative paths: ..\file.txt, .\config +// - UNC paths: \\network\share\file +// +// Returns an error if any flag contains a forward slash or backslash. +func ValidateRuntimeFlags(flags []string) error { + for _, flag := range flags { + if strings.Contains(flag, "/") || strings.Contains(flag, "\\") { + return fmt.Errorf("invalid runtime flag %q: paths are not allowed (contains '/' or '\\\\')", flag) + } + } + return nil +} diff --git a/pkg/inference/runtime_flags_test.go b/pkg/inference/runtime_flags_test.go new file mode 100644 index 00000000..0f2543aa --- /dev/null +++ b/pkg/inference/runtime_flags_test.go @@ -0,0 +1,209 @@ +package inference + +import ( + "testing" +) + +func TestValidateRuntimeFlags(t *testing.T) { + tests := []struct { + name string + flags []string + expectError bool + description string + }{ + { + name: "empty flags", + flags: []string{}, + expectError: false, + description: "Empty array should pass validation", + }, + { + name: "nil flags", + flags: nil, + expectError: false, + description: "Nil array should pass validation", + }, + { + name: "valid flags without paths", + flags: []string{"--verbose", "--debug", "--threads", "4"}, + expectError: false, + description: "Simple flags without paths should pass", + }, + { + name: "valid single character flags", + flags: []string{"-v", "-d", "-t", "4"}, + expectError: false, + description: "Single character flags should pass", + }, + { + name: "valid flags with numbers and hyphens", + flags: []string{"--gpu-memory-utilization", "0.9", "--max-tokens", "1024"}, + expectError: false, + description: "Flags with hyphens and numeric values should pass", + }, + { + name: "reject absolute path in value", + flags: []string{"--log-file", "/var/log/model.log"}, + expectError: true, + description: "Absolute paths should be rejected", + }, + { + name: "reject absolute path in flag=value format", + flags: []string{"--log-file=/var/log/model.log"}, + expectError: true, + description: "Paths in flag=value format should be rejected", + }, + { + name: "reject relative path with parent directory", + flags: []string{"--output", "../file.txt"}, + expectError: true, + description: "Relative paths with ../ should be rejected", + }, + { + name: "reject relative path with current directory", + flags: []string{"--config", "./config.yaml"}, + expectError: true, + description: "Relative paths with ./ should be rejected", + }, + { + name: "reject Windows-style path with forward slash", + flags: []string{"--file", "C:/Users/file.txt"}, + expectError: true, + description: "Windows-style paths with forward slash should be rejected", + }, + { + name: "reject Windows-style path with backslash", + flags: []string{"--file", "C:\\Users\\file.txt"}, + expectError: true, + description: "Windows-style paths with backslash should be rejected", + }, + { + name: "reject Windows relative path with backslash", + flags: []string{"--config", "..\\config.yaml"}, + expectError: true, + description: "Windows relative paths with backslash should be rejected", + }, + { + name: "reject Windows current directory path", + flags: []string{"--output", ".\\output.txt"}, + expectError: true, + description: "Windows current directory paths should be rejected", + }, + { + name: "reject UNC network path", + flags: []string{"--share", "\\\\server\\share\\file.txt"}, + expectError: true, + description: "UNC network paths should be rejected", + }, + { + name: "reject Windows system path", + flags: []string{"--log", "C:\\Windows\\System32\\log.txt"}, + expectError: true, + description: "Windows system paths should be rejected", + }, + { + name: "reject URL with http", + flags: []string{"--endpoint", "http://example.com/api"}, + expectError: true, + description: "URLs should be rejected (conservative approach)", + }, + { + name: "reject URL with https", + flags: []string{"--api-url", "https://api.example.com/v1"}, + expectError: true, + description: "HTTPS URLs should be rejected (conservative approach)", + }, + { + name: "reject path in middle of flag list", + flags: []string{"--verbose", "--log-file", "/tmp/log.txt", "--debug"}, + expectError: true, + description: "Path anywhere in flag list should be rejected", + }, + { + name: "reject multiple paths", + flags: []string{"--input", "/path/to/input", "--output", "/path/to/output"}, + expectError: true, + description: "Multiple paths should be rejected", + }, + { + name: "reject path traversal attempt", + flags: []string{"--file", "../../etc/passwd"}, + expectError: true, + description: "Path traversal attempts should be rejected", + }, + { + name: "reject root directory", + flags: []string{"--root", "/"}, + expectError: true, + description: "Root directory should be rejected", + }, + { + name: "reject home directory path", + flags: []string{"--home", "/home/user/.config"}, + expectError: true, + description: "Home directory paths should be rejected", + }, + { + name: "valid flag with special characters except slash", + flags: []string{"--model-name", "llama-3.2-1b", "--temperature", "0.7"}, + expectError: false, + description: "Flags with dots, hyphens, and numbers (no slash) should pass", + }, + { + name: "valid flag with underscore", + flags: []string{"--max_tokens", "512", "--use_cache"}, + expectError: false, + description: "Flags with underscores should pass", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateRuntimeFlags(tt.flags) + + if tt.expectError { + if err == nil { + t.Errorf("%s: expected error but got none", tt.description) + } + } else { + if err != nil { + t.Errorf("%s: unexpected error: %v", tt.description, err) + } + } + }) + } +} + +func TestValidateRuntimeFlags_ErrorMessage(t *testing.T) { + // Test that error messages are helpful + flags := []string{"--log-file", "/var/log/test.log"} + err := ValidateRuntimeFlags(flags) + + if err == nil { + t.Fatal("Expected error but got none") + } + + errMsg := err.Error() + if !contains(errMsg, "/var/log/test.log") { + t.Errorf("Error message should contain the offending flag value, got: %s", errMsg) + } + if !contains(errMsg, "paths are not allowed") { + t.Errorf("Error message should explain why it failed, got: %s", errMsg) + } +} + +// contains is a helper function to check if a string contains a substring +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || substr == "" || + (s != "" && indexOf(s, substr) >= 0)) +} + +// indexOf returns the index of substr in s, or -1 if not found +func indexOf(s, substr string) int { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + return -1 +} diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go index a1b0ca37..e899daa4 100644 --- a/pkg/inference/scheduling/api.go +++ b/pkg/inference/scheduling/api.go @@ -93,7 +93,8 @@ type UnloadResponse struct { // ConfigureRequest specifies per-model runtime configuration options. type ConfigureRequest struct { - Model string `json:"model"` - Mode *inference.BackendMode `json:"mode,omitempty"` + Model string `json:"model"` + Mode *inference.BackendMode `json:"mode,omitempty"` + RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"` inference.BackendConfiguration } diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go index 1c29defb..fdafb1ee 100644 --- a/pkg/inference/scheduling/scheduler.go +++ b/pkg/inference/scheduling/scheduler.go @@ -3,7 +3,9 @@ package scheduling import ( "context" "errors" + "fmt" "net/http" + "slices" "time" "github.com/docker/model-runner/pkg/distribution/types" @@ -14,6 +16,7 @@ import ( "github.com/docker/model-runner/pkg/internal/utils" "github.com/docker/model-runner/pkg/logging" "github.com/docker/model-runner/pkg/metrics" + "github.com/mattn/go-shellwords" "golang.org/x/sync/errgroup" ) @@ -225,10 +228,28 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe backend = s.defaultBackend } + // Parse runtime flags from either array or raw string + var runtimeFlags []string + if len(req.RuntimeFlags) > 0 { + runtimeFlags = req.RuntimeFlags + } else if req.RawRuntimeFlags != "" { + var err error + runtimeFlags, err = shellwords.Parse(req.RawRuntimeFlags) + if err != nil { + return nil, fmt.Errorf("invalid runtime flags: %w", err) + } + } + + // Validate runtime flags to prevent path-based security issues + if err := inference.ValidateRuntimeFlags(runtimeFlags); err != nil { + return nil, err + } + // Build runner configuration with shared settings var runnerConfig inference.BackendConfiguration runnerConfig.ContextSize = req.ContextSize runnerConfig.Speculative = req.Speculative + runnerConfig.RuntimeFlags = runtimeFlags // Set vLLM-specific configuration if provided if req.VLLM != nil { @@ -255,6 +276,8 @@ func (s *Scheduler) ConfigureRunner(ctx context.Context, backend inference.Backe mode := inference.BackendModeCompletion if req.Mode != nil { mode = *req.Mode + } else if slices.Contains(runnerConfig.RuntimeFlags, "--embeddings") { + mode = inference.BackendModeEmbedding } // Get model, track usage, and select appropriate backend