1 change: 1 addition & 0 deletions go.mod
@@ -35,6 +35,7 @@ require (
github.com/jaypipes/pcidb v1.0.1 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.17.11 // indirect
github.com/mattn/go-shellwords v1.0.12 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/moby/locker v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
2 changes: 2 additions & 0 deletions go.sum
@@ -75,6 +75,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk=
github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
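The new go-shellwords dependency is used in scheduler.go (further down in this diff) to split the raw-runtime-flags string into argv-style tokens. A minimal, self-contained sketch of that behavior, assuming only the library's public Parse API; the flag values are purely illustrative:

package main

import (
	"fmt"

	shellwords "github.com/mattn/go-shellwords"
)

func main() {
	// A raw runtime-flag string, as it would arrive in a request body.
	raw := `--threads 4 --override-kv "tokenizer.ggml.add_bos_token=bool:false"`

	// Parse splits the string into argv-style tokens, honoring quoting.
	flags, err := shellwords.Parse(raw)
	if err != nil {
		panic(err)
	}
	fmt.Printf("%q\n", flags)
	// ["--threads" "4" "--override-kv" "tokenizer.ggml.add_bos_token=bool:false"]
}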
7 changes: 6 additions & 1 deletion pkg/inference/backend.go
@@ -29,6 +29,11 @@ func (m BackendMode) String() string {
}
}

type BackendConfiguration struct {
ContextSize int64
RawFlags []string
}

// Backend is the interface implemented by inference engine backends. Backend
// implementations need not be safe for concurrent invocation of the following
// methods, though their underlying server implementations do need to support
@@ -66,7 +71,7 @@ type Backend interface {
// to be loaded. Backends should not load multiple models at once and should
// instead load only the specified model. Backends should still respond to
// OpenAI API requests for other models with a 421 error code.
Run(ctx context.Context, socket, model string, mode BackendMode) error
Run(ctx context.Context, socket, model string, mode BackendMode, config *BackendConfiguration) error
// Status returns a description of the backend's state.
Status() string
// GetDiskUsage returns the disk usage of the backend.
11 changes: 10 additions & 1 deletion pkg/inference/backends/llamacpp/llamacpp.go
@@ -11,6 +11,7 @@ import (
"os/exec"
"path/filepath"
"runtime"
"strconv"
"strings"

"github.com/docker/model-runner/pkg/diskusage"
@@ -120,7 +121,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
}

// Run implements inference.Backend.Run.
func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
modelPath, err := l.modelManager.GetModelPath(model)
l.log.Infof("Model path: %s", modelPath)
if err != nil {
@@ -138,6 +139,14 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
}

args := l.config.GetArgs(modelPath, socket, mode)

if config != nil {
if config.ContextSize >= 0 {
args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize)))
}
args = append(args, config.RawFlags...)
}

l.log.Infof("llamaCppArgs: %v", args)
llamaCppProcess := exec.CommandContext(
ctx,
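For illustration only (not part of the diff): a minimal sketch of how the appended flags come out for a sample configuration. The base arguments and model path are placeholders standing in for l.config.GetArgs(...), and the local struct mirrors inference.BackendConfiguration so the snippet compiles on its own.

package main

import (
	"fmt"
	"strconv"
)

// backendConfiguration mirrors inference.BackendConfiguration so this sketch
// is self-contained; the real type lives in pkg/inference.
type backendConfiguration struct {
	ContextSize int64
	RawFlags    []string
}

func main() {
	// Hypothetical configuration, e.g. stored by an earlier _configure request.
	config := &backendConfiguration{
		ContextSize: 4096,
		RawFlags:    []string{"--threads", "4"},
	}

	// Placeholder base arguments standing in for the backend's default args.
	args := []string{"--model", "/models/example.gguf", "--host", "unix:///tmp/runner.sock"}

	if config != nil {
		if config.ContextSize >= 0 {
			args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize)))
		}
		args = append(args, config.RawFlags...)
	}

	fmt.Println(args)
	// [--model /models/example.gguf --host unix:///tmp/runner.sock --ctx-size 4096 --threads 4]
}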
2 changes: 1 addition & 1 deletion pkg/inference/backends/mlx/mlx.go
@@ -49,7 +49,7 @@ func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error {
}

// Run implements inference.Backend.Run.
func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
// TODO: Implement.
m.log.Warn("MLX backend is not yet supported")
return errors.New("not implemented")
2 changes: 1 addition & 1 deletion pkg/inference/backends/vllm/vllm.go
@@ -49,7 +49,7 @@ func (v *vLLM) Install(ctx context.Context, httpClient *http.Client) error {
}

// Run implements inference.Backend.Run.
func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
// TODO: Implement.
v.log.Warn("vLLM backend is not yet supported")
return errors.New("not implemented")
7 changes: 7 additions & 0 deletions pkg/inference/scheduling/api.go
@@ -73,3 +73,10 @@ type UnloadRequest struct {
type UnloadResponse struct {
UnloadedRunners int `json:"unloaded_runners"`
}

// ConfigureRequest specifies per-model runtime configuration options.
type ConfigureRequest struct {
Model string `json:"model"`
ContextSize int64 `json:"context-size,omitempty"`
RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
}
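A rough sketch of the wire format implied by the struct tags above; the model reference and values are illustrative, and the local type mirrors ConfigureRequest so the snippet is self-contained.

package main

import (
	"encoding/json"
	"fmt"
)

// configureRequest mirrors scheduling.ConfigureRequest; field names and JSON
// tags match the struct defined in api.go above.
type configureRequest struct {
	Model           string `json:"model"`
	ContextSize     int64  `json:"context-size,omitempty"`
	RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
}

func main() {
	body, err := json.Marshal(configureRequest{
		Model:           "ai/example-model", // illustrative model reference
		ContextSize:     8192,
		RawRuntimeFlags: "--threads 4",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body))
	// {"model":"ai/example-model","context-size":8192,"raw-runtime-flags":"--threads 4"}
}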
29 changes: 28 additions & 1 deletion pkg/inference/scheduling/loader.go
@@ -29,6 +29,9 @@
// errModelTooBig indicates that the model is too big to ever load into the
// available system memory.
errModelTooBig = errors.New("model too big")
// errRunnerAlreadyActive indicates that a given runner is already active
// and therefore can't be reconfigured.
errRunnerAlreadyActive = errors.New("runner already active")
)

// runnerKey is used to index runners.
@@ -82,6 +85,8 @@
// timestamps maps slot indices to last usage times. Values in this slice
// are only valid if the corresponding reference count is zero.
timestamps []time.Time
// runnerConfigs maps runner keys to runner configurations.
runnerConfigs map[runnerKey]inference.BackendConfiguration
}

// newLoader creates a new loader.
@@ -122,6 +127,7 @@
references: make([]uint, nSlots),
allocations: make([]uint64, nSlots),
timestamps: make([]time.Time, nSlots),
runnerConfigs: make(map[runnerKey]inference.BackendConfiguration),
}
l.guard <- struct{}{}
return l
@@ -214,9 +220,11 @@

return len(l.runners) - func() int {
if unload.All {
l.runnerConfigs = make(map[runnerKey]inference.BackendConfiguration)
return l.evict(false)
} else {
for _, model := range unload.Models {
delete(l.runnerConfigs, runnerKey{unload.Backend, model, inference.BackendModeCompletion})
// Evict both completion and embedding models. We should consider
// accepting a mode parameter in unload requests.
l.evictRunner(unload.Backend, model, inference.BackendModeCompletion)
@@ -413,9 +421,13 @@

// If we've identified a slot, then we're ready to start a runner.
if slot >= 0 {
var runnerConfig *inference.BackendConfiguration
if rc, ok := l.runnerConfigs[runnerKey{backendName, model, mode}]; ok {
runnerConfig = &rc
}
// Create the runner.
l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, model, mode)
runner, err := run(l.log, backend, model, mode, slot)
runner, err := run(l.log, backend, model, mode, slot, runnerConfig)
if err != nil {
l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v",
backendName, model, mode, err,
@@ -492,3 +504,18 @@
// Signal waiters.
l.broadcast()
}

func (l *loader) setRunnerConfig(ctx context.Context, backendName, model string, mode inference.BackendMode, runnerConfig inference.BackendConfiguration) error {
l.lock(ctx)
defer l.unlock()

runnerId := runnerKey{backendName, model, mode}

if _, ok := l.runners[runnerId]; ok {
return errRunnerAlreadyActive
}

l.log.Infof("Configuring %s runner for %s", backendName, model)

Check failure (Code scanning / CodeQL): Log entries created from user input (High)

This log entry depends on a user-provided value.

Copilot Autofix (AI, 7 months ago):

To fix the issue, the user-provided model value should be sanitized before being logged. Since the log entries are plain text, we can remove newline characters (\n and \r) from the model string using strings.ReplaceAll. This ensures that malicious input cannot introduce new log entries or otherwise manipulate the log format.

The fix involves modifying the setRunnerConfig method in loader.go to sanitize the model parameter before logging it. Specifically:

  1. Use strings.ReplaceAll to remove \n and \r characters from the model string.
  2. Log the sanitized version of the model string.

Suggested changeset 1: pkg/inference/scheduling/loader.go

Autofix patch. Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go
--- a/pkg/inference/scheduling/loader.go
+++ b/pkg/inference/scheduling/loader.go
@@ -517,3 +517,5 @@
 
-	l.log.Infof("Configuring %s runner for %s", backendName, model)
+	sanitizedModel := strings.ReplaceAll(model, "\n", "")
+	sanitizedModel = strings.ReplaceAll(sanitizedModel, "\r", "")
+	l.log.Infof("Configuring %s runner for %s", backendName, sanitizedModel)
 	l.runnerConfigs[runnerId] = runnerConfig
EOF
l.runnerConfigs[runnerId] = runnerConfig
return nil
}
3 changes: 2 additions & 1 deletion pkg/inference/scheduling/runner.go
@@ -73,6 +73,7 @@ func run(
model string,
mode inference.BackendMode,
slot int,
runnerConfig *inference.BackendConfiguration,
) (*runner, error) {
// Create a dialer / transport that target backend on the specified slot.
socket, err := RunnerSocketPath(slot)
@@ -152,7 +153,7 @@ func run(

// Start the backend run loop.
go func() {
if err := backend.Run(runCtx, socket, model, mode); err != nil {
if err := backend.Run(runCtx, socket, model, mode, runnerConfig); err != nil {
log.Warnf("Backend %s running model %s exited with error: %v",
backend.Name(), model, err,
)
58 changes: 58 additions & 0 deletions pkg/inference/scheduling/scheduler.go
@@ -17,6 +17,7 @@
"github.com/docker/model-runner/pkg/inference/models"
"github.com/docker/model-runner/pkg/logging"
"github.com/docker/model-runner/pkg/metrics"
"github.com/mattn/go-shellwords"
"golang.org/x/sync/errgroup"
)

@@ -112,6 +113,8 @@
m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
m["GET "+inference.InferencePrefix+"/df"] = s.GetDiskUsage
m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
return m
}

@@ -347,6 +350,61 @@
}
}

func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) {
// Determine the requested backend and ensure that it's valid.
var backend inference.Backend
if b := r.PathValue("backend"); b == "" {
backend = s.defaultBackend
} else {
backend = s.backends[b]
}
if backend == nil {
http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
return
}

body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
if err != nil {
if _, ok := err.(*http.MaxBytesError); ok {
http.Error(w, "request too large", http.StatusBadRequest)
} else {
http.Error(w, "unknown error", http.StatusInternalServerError)
}
return
}

configureRequest := ConfigureRequest{
Model: "",
ContextSize: -1,
RawRuntimeFlags: "",
}
if err := json.Unmarshal(body, &configureRequest); err != nil {
http.Error(w, "invalid request", http.StatusBadRequest)
return
}
rawFlags, err := shellwords.Parse(configureRequest.RawRuntimeFlags)
if err != nil {
http.Error(w, "invalid request", http.StatusBadRequest)
return
}

var runnerConfig inference.BackendConfiguration
runnerConfig.ContextSize = configureRequest.ContextSize
runnerConfig.RawFlags = rawFlags

if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), configureRequest.Model, inference.BackendModeCompletion, runnerConfig); err != nil {
s.log.Warnf("Failed to configure %s runner for %s: %s", backend.Name(), configureRequest.Model, err)

Check failure (Code scanning / CodeQL): Log entries created from user input (High)

This log entry depends on a user-provided value.

Copilot Autofix (AI, 7 months ago): Copilot could not generate an autofix suggestion for this alert.

if errors.Is(err, errRunnerAlreadyActive) {
http.Error(w, err.Error(), http.StatusConflict)
} else {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
return
}

w.WriteHeader(http.StatusAccepted)
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
s.lock.Lock()
Expand Down