diff --git a/go.mod b/go.mod
index 91ce38df1..bf3e67d90 100644
--- a/go.mod
+++ b/go.mod
@@ -35,6 +35,7 @@ require (
 	github.com/jaypipes/pcidb v1.0.1 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/klauspost/compress v1.17.11 // indirect
+	github.com/mattn/go-shellwords v1.0.12 // indirect
 	github.com/mitchellh/go-homedir v1.1.0 // indirect
 	github.com/moby/locker v1.0.1 // indirect
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
diff --git a/go.sum b/go.sum
index 2333131bf..8f2290d64 100644
--- a/go.sum
+++ b/go.sum
@@ -75,6 +75,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk=
+github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y=
 github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
 github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
 github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=
diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go
index 37ad7d0ab..48676626c 100644
--- a/pkg/inference/backend.go
+++ b/pkg/inference/backend.go
@@ -29,6 +29,11 @@ func (m BackendMode) String() string {
 	}
 }
 
+type BackendConfiguration struct {
+	ContextSize int64
+	RawFlags    []string
+}
+
 // Backend is the interface implemented by inference engine backends. Backend
 // implementations need not be safe for concurrent invocation of the following
 // methods, though their underlying server implementations do need to support
@@ -66,7 +71,7 @@ type Backend interface {
 	// to be loaded. Backends should not load multiple models at once and should
 	// instead load only the specified model. Backends should still respond to
 	// OpenAI API requests for other models with a 421 error code.
-	Run(ctx context.Context, socket, model string, mode BackendMode) error
+	Run(ctx context.Context, socket, model string, mode BackendMode, config *BackendConfiguration) error
 	// Status returns a description of the backend's state.
 	Status() string
 	// GetDiskUsage returns the disk usage of the backend.
diff --git a/pkg/inference/backends/llamacpp/llamacpp.go b/pkg/inference/backends/llamacpp/llamacpp.go
index 09cd91b41..930535daa 100644
--- a/pkg/inference/backends/llamacpp/llamacpp.go
+++ b/pkg/inference/backends/llamacpp/llamacpp.go
@@ -11,6 +11,7 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
+	"strconv"
 	"strings"
 
 	"github.com/docker/model-runner/pkg/diskusage"
@@ -120,7 +121,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
 }
 
 // Run implements inference.Backend.Run.
-func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
+func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
 	modelPath, err := l.modelManager.GetModelPath(model)
 	l.log.Infof("Model path: %s", modelPath)
 	if err != nil {
@@ -138,6 +139,14 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 	}
 
 	args := l.config.GetArgs(modelPath, socket, mode)
+
+	if config != nil {
+		if config.ContextSize >= 0 {
+			args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize)))
+		}
+		args = append(args, config.RawFlags...)
+	}
+
 	l.log.Infof("llamaCppArgs: %v", args)
 	llamaCppProcess := exec.CommandContext(
 		ctx,
diff --git a/pkg/inference/backends/mlx/mlx.go b/pkg/inference/backends/mlx/mlx.go
index 7778dfbaa..d6cf86e09 100644
--- a/pkg/inference/backends/mlx/mlx.go
+++ b/pkg/inference/backends/mlx/mlx.go
@@ -49,7 +49,7 @@ func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error {
 }
 
 // Run implements inference.Backend.Run.
-func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
+func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
 	// TODO: Implement.
 	m.log.Warn("MLX backend is not yet supported")
 	return errors.New("not implemented")
diff --git a/pkg/inference/backends/vllm/vllm.go b/pkg/inference/backends/vllm/vllm.go
index 9ea8bcf83..c03c367ad 100644
--- a/pkg/inference/backends/vllm/vllm.go
+++ b/pkg/inference/backends/vllm/vllm.go
@@ -49,7 +49,7 @@ func (v *vLLM) Install(ctx context.Context, httpClient *http.Client) error {
 }
 
 // Run implements inference.Backend.Run.
-func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
+func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
 	// TODO: Implement.
 	v.log.Warn("vLLM backend is not yet supported")
 	return errors.New("not implemented")
diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go
index e574a950e..ebd40571d 100644
--- a/pkg/inference/scheduling/api.go
+++ b/pkg/inference/scheduling/api.go
@@ -73,3 +73,10 @@ type UnloadRequest struct {
 type UnloadResponse struct {
 	UnloadedRunners int `json:"unloaded_runners"`
 }
+
+// ConfigureRequest specifies per-model runtime configuration options.
+type ConfigureRequest struct {
+	Model           string `json:"model"`
+	ContextSize     int64  `json:"context-size,omitempty"`
+	RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
+}
diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go
index 59a2e8dea..60c1eec49 100644
--- a/pkg/inference/scheduling/loader.go
+++ b/pkg/inference/scheduling/loader.go
@@ -29,6 +29,9 @@ var (
 	// errModelTooBig indicates that the model is too big to ever load into the
 	// available system memory.
 	errModelTooBig = errors.New("model too big")
+	// errRunnerAlreadyActive indicates that a given runner is already active
+	// and therefore can't be reconfigured.
+	errRunnerAlreadyActive = errors.New("runner already active")
 )
 
 // runnerKey is used to index runners.
@@ -82,6 +85,8 @@ type loader struct {
 	// timestamps maps slot indices to last usage times. Values in this slice
 	// are only valid if the corresponding reference count is zero.
 	timestamps []time.Time
+	// runnerConfigs maps runner keys to their backend runner configurations.
+	runnerConfigs map[runnerKey]inference.BackendConfiguration
 }
 
 // newLoader creates a new loader.
@@ -122,6 +127,7 @@ func newLoader(
 		references:    make([]uint, nSlots),
 		allocations:   make([]uint64, nSlots),
 		timestamps:    make([]time.Time, nSlots),
+		runnerConfigs: make(map[runnerKey]inference.BackendConfiguration),
 	}
 	l.guard <- struct{}{}
 	return l
@@ -214,9 +220,11 @@ func (l *loader) Unload(ctx context.Context, unload UnloadRequest) int {
 
 	return len(l.runners) - func() int {
 		if unload.All {
+			l.runnerConfigs = make(map[runnerKey]inference.BackendConfiguration)
 			return l.evict(false)
 		} else {
 			for _, model := range unload.Models {
+				delete(l.runnerConfigs, runnerKey{unload.Backend, model, inference.BackendModeCompletion})
 				// Evict both, completion and embedding models. We should consider
 				// accepting a mode parameter in unload requests.
 				l.evictRunner(unload.Backend, model, inference.BackendModeCompletion)
@@ -413,9 +421,13 @@ func (l *loader) load(ctx context.Context, backendName, model string, mode infer
 
 	// If we've identified a slot, then we're ready to start a runner.
 	if slot >= 0 {
+		var runnerConfig *inference.BackendConfiguration
+		if rc, ok := l.runnerConfigs[runnerKey{backendName, model, mode}]; ok {
+			runnerConfig = &rc
+		}
 		// Create the runner.
 		l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, model, mode)
-		runner, err := run(l.log, backend, model, mode, slot)
+		runner, err := run(l.log, backend, model, mode, slot, runnerConfig)
 		if err != nil {
 			l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v",
 				backendName, model, mode, err,
@@ -492,3 +504,18 @@ func (l *loader) release(runner *runner) {
 	// Signal waiters.
 	l.broadcast()
 }
+
+func (l *loader) setRunnerConfig(ctx context.Context, backendName, model string, mode inference.BackendMode, runnerConfig inference.BackendConfiguration) error {
+	l.lock(ctx)
+	defer l.unlock()
+
+	runnerId := runnerKey{backendName, model, mode}
+
+	if _, ok := l.runners[runnerId]; ok {
+		return errRunnerAlreadyActive
+	}
+
+	l.log.Infof("Configuring %s runner for %s", backendName, model)
+	l.runnerConfigs[runnerId] = runnerConfig
+	return nil
+}
diff --git a/pkg/inference/scheduling/runner.go b/pkg/inference/scheduling/runner.go
index ea0dd1495..43f28e48e 100644
--- a/pkg/inference/scheduling/runner.go
+++ b/pkg/inference/scheduling/runner.go
@@ -73,6 +73,7 @@ func run(
 	model string,
 	mode inference.BackendMode,
 	slot int,
+	runnerConfig *inference.BackendConfiguration,
 ) (*runner, error) {
 	// Create a dialer / transport that target backend on the specified slot.
 	socket, err := RunnerSocketPath(slot)
@@ -152,7 +153,7 @@ func run(
 
 	// Start the backend run loop.
 	go func() {
-		if err := backend.Run(runCtx, socket, model, mode); err != nil {
+		if err := backend.Run(runCtx, socket, model, mode, runnerConfig); err != nil {
 			log.Warnf("Backend %s running model %s exited with error: %v",
 				backend.Name(), model, err,
 			)
diff --git a/pkg/inference/scheduling/scheduler.go b/pkg/inference/scheduling/scheduler.go
index 40e536ed6..8fbe721d8 100644
--- a/pkg/inference/scheduling/scheduler.go
+++ b/pkg/inference/scheduling/scheduler.go
@@ -17,6 +17,7 @@ import (
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/metrics"
+	"github.com/mattn/go-shellwords"
 	"golang.org/x/sync/errgroup"
 )
 
@@ -112,6 +113,8 @@ func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.Handl
 	m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
 	m["GET "+inference.InferencePrefix+"/df"] = s.GetDiskUsage
 	m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
+	m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
+	m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
 	return m
 }
 
@@ -347,6 +350,61 @@ func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) {
+	// Determine the requested backend and ensure that it's valid.
+	var backend inference.Backend
+	if b := r.PathValue("backend"); b == "" {
+		backend = s.defaultBackend
+	} else {
+		backend = s.backends[b]
+	}
+	if backend == nil {
+		http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
+		return
+	}
+
+	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
+	if err != nil {
+		if _, ok := err.(*http.MaxBytesError); ok {
+			http.Error(w, "request too large", http.StatusBadRequest)
+		} else {
+			http.Error(w, "unknown error", http.StatusInternalServerError)
+		}
+		return
+	}
+
+	configureRequest := ConfigureRequest{
+		Model:           "",
+		ContextSize:     -1,
+		RawRuntimeFlags: "",
+	}
+	if err := json.Unmarshal(body, &configureRequest); err != nil {
+		http.Error(w, "invalid request", http.StatusBadRequest)
+		return
+	}
+	rawFlags, err := shellwords.Parse(configureRequest.RawRuntimeFlags)
+	if err != nil {
+		http.Error(w, "invalid request", http.StatusBadRequest)
+		return
+	}
+
+	var runnerConfig inference.BackendConfiguration
+	runnerConfig.ContextSize = configureRequest.ContextSize
+	runnerConfig.RawFlags = rawFlags
+
+	if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), configureRequest.Model, inference.BackendModeCompletion, runnerConfig); err != nil {
+		s.log.Warnf("Failed to configure %s runner for %s: %s", backend.Name(), configureRequest.Model, err)
+		if errors.Is(err, errRunnerAlreadyActive) {
+			http.Error(w, err.Error(), http.StatusConflict)
+		} else {
+			http.Error(w, err.Error(), http.StatusInternalServerError)
+		}
+		return
+	}
+
+	w.WriteHeader(http.StatusAccepted)
+}
+
 // ServeHTTP implements net/http.Handler.ServeHTTP.
 func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	s.lock.Lock()
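
For reviewers, a minimal client-side sketch of the new configure flow. This is not part of the change: it assumes the model runner is reachable over TCP on the host/port shown and that inference.InferencePrefix resolves to "/engines"; the model name and the llama.cpp flags are placeholder examples.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// configureRequest mirrors scheduling.ConfigureRequest from this change.
type configureRequest struct {
	Model           string `json:"model"`
	ContextSize     int64  `json:"context-size,omitempty"`
	RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
}

func main() {
	// Assumed base URL; adjust host, port, and prefix to your deployment.
	const baseURL = "http://localhost:12434/engines"

	payload, err := json.Marshal(configureRequest{
		Model:           "ai/example-model",      // hypothetical model name
		ContextSize:     8192,                    // forwarded to llama.cpp as --ctx-size
		RawRuntimeFlags: "--threads 4 --no-mmap", // parsed server-side by go-shellwords
	})
	if err != nil {
		panic(err)
	}

	// Omitting the {backend} path segment targets the default backend; a
	// backend name could be inserted before /_configure instead.
	resp, err := http.Post(baseURL+"/_configure", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Per the handler above: 202 Accepted when the configuration is stored,
	// 409 Conflict when the runner is already active, 404 for an unknown backend.
	fmt.Println(resp.Status)
}

Because setRunnerConfig is keyed by (backend, model, mode) and only consulted when a runner is loaded, reconfiguring a model that is currently running returns 409 Conflict until its runner is unloaded via the existing unload endpoint.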