Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions cmd/cli/commands/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func newComposeCmd() *cobra.Command {
func newUpCommand() *cobra.Command {
var models []string
var ctxSize int64
var rawRuntimeFlags string
var backend string
var draftModel string
var numTokens int
Expand Down Expand Up @@ -69,6 +70,9 @@ func newUpCommand() *cobra.Command {
if ctxSize > 0 {
sendInfo(fmt.Sprintf("Setting context size to %d", ctxSize))
}
if rawRuntimeFlags != "" {
sendInfo("Setting raw runtime flags to " + rawRuntimeFlags)
}

// Build speculative config if any speculative flags are set
var speculativeConfig *inference.SpeculativeDecodingConfig
Expand All @@ -89,10 +93,11 @@ func newUpCommand() *cobra.Command {
ContextSize: &size,
Speculative: speculativeConfig,
},
RawRuntimeFlags: rawRuntimeFlags,
}); err != nil {
configErrFmtString := "failed to configure backend for model %s with context-size %d"
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, err)
return fmt.Errorf(configErrFmtString+": %w", model, ctxSize, err)
configErrFmtString := "failed to configure backend for model %s with context-size %d and runtime-flags %s"
_ = sendErrorf(configErrFmtString+": %v", model, ctxSize, rawRuntimeFlags, err)
return fmt.Errorf(configErrFmtString+": %w", model, ctxSize, rawRuntimeFlags, err)
}
sendInfo("Successfully configured backend for model " + model)
}
Expand All @@ -114,6 +119,7 @@ func newUpCommand() *cobra.Command {
}
c.Flags().StringArrayVar(&models, "model", nil, "model to use")
c.Flags().Int64Var(&ctxSize, "context-size", -1, "context size for the model")
c.Flags().StringVar(&rawRuntimeFlags, "runtime-flags", "", "raw runtime flags to pass to the inference engine")
c.Flags().StringVar(&backend, "backend", llamacpp.Name, "inference backend to use")
c.Flags().StringVar(&draftModel, "speculative-draft-model", "", "draft model for speculative decoding")
c.Flags().IntVar(&numTokens, "speculative-num-tokens", 0, "number of tokens to predict speculatively")
Expand Down
25 changes: 19 additions & 6 deletions cmd/cli/commands/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,27 @@ func newConfigureCmd() *cobra.Command {
var flags ConfigureFlags

c := &cobra.Command{
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL",
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]",
Short: "Configure runtime options for a model",
Hidden: true,
Args: func(cmd *cobra.Command, args []string) error {
if len(args) != 1 {
return fmt.Errorf(
"Exactly one model must be specified, got %d: %v\n\n"+
"See 'docker model configure --help' for more information",
len(args), args)
argsBeforeDash := cmd.ArgsLenAtDash()
if argsBeforeDash == -1 {
// No "--" used, so we need exactly 1 total argument.
if len(args) != 1 {
return fmt.Errorf(
"Exactly one model must be specified, got %d: %v\n\n"+
"See 'docker model configure --help' for more information",
len(args), args)
}
} else {
// Has "--", so we need exactly 1 argument before it.
if argsBeforeDash != 1 {
return fmt.Errorf(
"Exactly one model must be specified before --, got %d\n\n"+
"See 'docker model configure --help' for more information",
argsBeforeDash)
}
}
return nil
},
Expand All @@ -29,6 +41,7 @@ func newConfigureCmd() *cobra.Command {
if err != nil {
return err
}
opts.RuntimeFlags = args[1:]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Runtime flags slice includes the model name and likely the literal "--"

Here opts.RuntimeFlags = args[1:] will still include the literal "--" when present (e.g. docker model configure foo -- --embeddings yields RuntimeFlags = ["--", "--embeddings"]). To avoid depending on positional slicing and to drop "--" explicitly, consider mirroring the Args logic and using cmd.ArgsLenAtDash(): treat args[:argsBeforeDash] as the model (length 1) and args[argsBeforeDash+1:] as runtime flags.

return desktopClient.ConfigureBackend(opts)
},
ValidArgsFunction: completion.ModelNames(getDesktopClient, -1),
Expand Down
87 changes: 87 additions & 0 deletions cmd/cli/commands/configure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,90 @@ func TestThinkFlagBehavior(t *testing.T) {
})
}
}

// TestRuntimeFlagsValidation verifies that BuildConfigureRequest preserves
// RuntimeFlags unchanged so that downstream validation (which rejects paths
// in flag values) sees exactly what the user supplied. The path-rejection
// logic itself lives in the scheduler (inference.ValidateRuntimeFlags) and
// is not exercised here.
func TestRuntimeFlagsValidation(t *testing.T) {
	tests := []struct {
		name         string
		runtimeFlags []string
	}{
		{
			name:         "valid runtime flags without paths",
			runtimeFlags: []string{"--verbose", "--threads", "4"},
		},
		{
			name:         "empty runtime flags",
			runtimeFlags: []string{},
		},
		{
			name:         "absolute path in value",
			runtimeFlags: []string{"--log-file", "/var/log/model.log"},
		},
		{
			name:         "absolute path in flag=value format",
			runtimeFlags: []string{"--output-file=/tmp/output.txt"},
		},
		{
			name:         "relative path",
			runtimeFlags: []string{"--config", "../config.yaml"},
		},
		{
			name:         "URL",
			runtimeFlags: []string{"--endpoint", "http://example.com/api"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			flags := ConfigureFlags{}
			req, err := flags.BuildConfigureRequest("test-model")
			if err != nil {
				t.Fatalf("BuildConfigureRequest failed: %v", err)
			}

			// Set runtime flags after building the request, mirroring how the
			// configure command wires them in from the post-"--" arguments.
			req.RuntimeFlags = tt.runtimeFlags

			// Whether a flag set is ultimately accepted or rejected is decided
			// by the scheduler; the CLI contract tested here is only that the
			// flags pass through the request verbatim.
			if !equalStringSlices(req.RuntimeFlags, tt.runtimeFlags) {
				t.Errorf("Expected RuntimeFlags %v, got %v", tt.runtimeFlags, req.RuntimeFlags)
			}
		})
	}
}

// equalStringSlices reports whether a and b have the same length and contain
// identical elements in the same order. Two nil/empty slices are equal.
func equalStringSlices(a, b []string) bool {
	if len(a) != len(b) {
		return false
	}
	for i, elem := range a {
		if elem != b[i] {
			return false
		}
	}
	return true
}
9 changes: 9 additions & 0 deletions cmd/cli/docs/reference/docker_model_compose_up.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ options:
experimentalcli: false
kubernetes: false
swarm: false
- option: runtime-flags
value_type: string
description: raw runtime flags to pass to the inference engine
deprecated: false
hidden: false
experimental: false
experimentalcli: false
kubernetes: false
swarm: false
- option: speculative-draft-model
value_type: string
description: draft model for speculative decoding
Expand Down
2 changes: 1 addition & 1 deletion cmd/cli/docs/reference/docker_model_configure.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
command: docker model configure
short: Configure runtime options for a model
long: Configure runtime options for a model
usage: docker model configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL
usage: docker model configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]
pname: docker model
plink: docker_model.yaml
options:
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ require (
github.com/gpustack/gguf-parser-go v0.22.1
github.com/jaypipes/ghw v0.19.1
github.com/kolesnikovae/go-winjob v1.0.0
github.com/mattn/go-shellwords v1.0.12
github.com/opencontainers/go-digest v1.0.0
github.com/opencontainers/image-spec v1.1.1
github.com/prometheus/client_model v0.6.2
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk=
github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
Expand Down
5 changes: 3 additions & 2 deletions pkg/inference/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ type LlamaCppConfig struct {

type BackendConfiguration struct {
// Shared configuration across all backends
ContextSize *int32 `json:"context-size,omitempty"`
Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"`
ContextSize *int32 `json:"context-size,omitempty"`
RuntimeFlags []string `json:"runtime-flags,omitempty"`
Speculative *SpeculativeDecodingConfig `json:"speculative,omitempty"`

// Backend-specific configuration
VLLM *VLLMConfig `json:"vllm,omitempty"`
Expand Down
5 changes: 5 additions & 0 deletions pkg/inference/backends/llamacpp/llamacpp_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
args = append(args, "--ctx-size", strconv.FormatInt(int64(*contextSize), 10))
}

// Add arguments from backend config
if config != nil {
args = append(args, config.RuntimeFlags...)
}

// Add arguments for Multimodal projector or jinja (they are mutually exclusive)
if path := bundle.MMPROJPath(); path != "" {
args = append(args, "--mmproj", path)
Expand Down
17 changes: 17 additions & 0 deletions pkg/inference/backends/llamacpp/llamacpp_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,23 @@ func TestGetArgs(t *testing.T) {
"--jinja",
),
},
{
name: "raw flags from backend config",
mode: inference.BackendModeEmbedding,
bundle: &fakeBundle{
ggufPath: modelPath,
},
config: &inference.BackendConfiguration{
RuntimeFlags: []string{"--some", "flag"},
},
expected: append(slices.Clone(baseArgs),
"--model", modelPath,
"--host", socket,
"--embeddings",
"--some", "flag",
"--jinja",
),
},
{
name: "multimodal projector removes jinja",
mode: inference.BackendModeCompletion,
Expand Down
6 changes: 5 additions & 1 deletion pkg/inference/backends/vllm/vllm_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ func (c *Config) GetArgs(bundle types.ModelBundle, socket string, mode inference
if maxLen := GetMaxModelLen(bundle.RuntimeConfig(), config); maxLen != nil {
args = append(args, "--max-model-len", strconv.FormatInt(int64(*maxLen), 10))
}
// If nil, vLLM will automatically derive from the model config

// Add runtime flags from backend config
if config != nil {
args = append(args, config.RuntimeFlags...)
}

// Add vLLM-specific arguments from backend config
if config != nil && config.VLLM != nil {
Expand Down
17 changes: 17 additions & 0 deletions pkg/inference/backends/vllm/vllm_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,23 @@ func TestGetArgs(t *testing.T) {
"8192",
},
},
{
name: "with runtime flags",
bundle: &mockModelBundle{
safetensorsPath: "/path/to/model",
},
config: &inference.BackendConfiguration{
RuntimeFlags: []string{"--gpu-memory-utilization", "0.9"},
},
expected: []string{
"serve",
"/path/to",
"--uds",
"/tmp/socket",
"--gpu-memory-utilization",
"0.9",
},
},
{
name: "with model context size (takes precedence)",
bundle: &mockModelBundle{
Expand Down
27 changes: 27 additions & 0 deletions pkg/inference/runtime_flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package inference

import (
"fmt"
"strings"
)

// ValidateRuntimeFlags ensures runtime flags don't contain paths (forward slash "/" or backslash "\")
// to prevent malicious users from overwriting host files via arguments like
// --log-file /some/path, --output-file /etc/passwd, or --log-file C:\Windows\file.
//
// This validation rejects any flag or value containing "/" or "\" to block:
//   - Unix/Linux/macOS absolute paths: /var/log/file, /etc/passwd
//   - Unix/Linux/macOS relative paths: ../file.txt, ./config
//   - Windows absolute paths: C:\Users\file, D:\data\file
//   - Windows relative paths: ..\file.txt, .\config
//   - UNC paths: \\network\share\file
//
// Returns an error naming the first offending flag, or nil if all flags are clean.
func ValidateRuntimeFlags(flags []string) error {
	for _, flag := range flags {
		// ContainsAny checks both separators in one pass; the raw string
		// literal avoids double-escaping the backslash.
		if strings.ContainsAny(flag, `/\`) {
			// Raw string so the message shows a single literal backslash
			// (the previous "\\\\" escape printed two).
			return fmt.Errorf(`invalid runtime flag %q: paths are not allowed (contains '/' or '\')`, flag)
		}
	}
	return nil
}
Loading
Loading