
Commit 099b122

Piotr Stankiewicz authored and committed
Support runner configuration (temporary solution)
We need to allow users to configure the model runtime, whether to control inference settings or low-level llama.cpp-specific settings. In the interest of unblocking users quickly, this patch adds a very simple mechanism for configuring runtime settings: a `_configure` endpoint is added per engine, and it accepts POST requests that set context-size and raw runtime CLI flags. Those settings are applied to every run of the given model until unload is called for that model or model-runner is terminated. This is a temporary solution and is therefore subject to change once a design for specifying runtime settings is finalised.

Signed-off-by: Piotr Stankiewicz <piotr.stankiewicz@docker.com>
1 parent 0130eb6 commit 099b122

File tree (10 files changed: +116 -6 lines)

go.mod
go.sum
pkg/inference/backend.go
pkg/inference/backends/llamacpp/llamacpp.go
pkg/inference/backends/mlx/mlx.go
pkg/inference/backends/vllm/vllm.go
pkg/inference/scheduling/api.go
pkg/inference/scheduling/loader.go
pkg/inference/scheduling/runner.go
pkg/inference/scheduling/scheduler.go


go.mod
Lines changed: 1 addition & 0 deletions

@@ -35,6 +35,7 @@ require (
 	github.com/jaypipes/pcidb v1.0.1 // indirect
 	github.com/json-iterator/go v1.1.12 // indirect
 	github.com/klauspost/compress v1.17.11 // indirect
+	github.com/mattn/go-shellwords v1.0.12 // indirect
 	github.com/mitchellh/go-homedir v1.1.0 // indirect
 	github.com/moby/locker v1.0.1 // indirect
 	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect

go.sum
Lines changed: 2 additions & 0 deletions

@@ -75,6 +75,8 @@ github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
 github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
+github.com/mattn/go-shellwords v1.0.12 h1:M2zGm7EW6UQJvDeQxo4T51eKPurbeFbe8WtebGE2xrk=
+github.com/mattn/go-shellwords v1.0.12/go.mod h1:EZzvwXDESEeg03EKmM+RmDnNOPKG4lLtQsUlTZDWQ8Y=
 github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
 github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
 github.com/moby/locker v1.0.1 h1:fOXqR41zeveg4fFODix+1Ch4mj/gT0NE1XJbp/epuBg=

pkg/inference/backend.go
Lines changed: 6 additions & 1 deletion

@@ -29,6 +29,11 @@ func (m BackendMode) String() string {
 	}
 }

+type BackendConfiguration struct {
+	ContextSize int64
+	RawFlags    []string
+}
+
 // Backend is the interface implemented by inference engine backends. Backend
 // implementations need not be safe for concurrent invocation of the following
 // methods, though their underlying server implementations do need to support
@@ -66,7 +71,7 @@ type Backend interface {
 	// to be loaded. Backends should not load multiple models at once and should
 	// instead load only the specified model. Backends should still respond to
 	// OpenAI API requests for other models with a 421 error code.
-	Run(ctx context.Context, socket, model string, mode BackendMode) error
+	Run(ctx context.Context, socket, model string, mode BackendMode, config *BackendConfiguration) error
 	// Status returns a description of the backend's state.
 	Status() string
 	// GetDiskUsage returns the disk usage of the backend.

pkg/inference/backends/llamacpp/llamacpp.go
Lines changed: 10 additions & 1 deletion

@@ -11,6 +11,7 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
+	"strconv"
 	"strings"

 	"github.com/docker/model-runner/pkg/diskusage"
@@ -120,7 +121,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
 }

 // Run implements inference.Backend.Run.
-func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
+func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
 	modelPath, err := l.modelManager.GetModelPath(model)
 	l.log.Infof("Model path: %s", modelPath)
 	if err != nil {
@@ -138,6 +139,14 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 	}

 	args := l.config.GetArgs(modelPath, socket, mode)
+
+	if config != nil {
+		if config.ContextSize >= 0 {
+			args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize)))
+		}
+		args = append(args, config.RawFlags...)
+	}
+
 	l.log.Infof("llamaCppArgs: %v", args)
 	llamaCppProcess := exec.CommandContext(
 		ctx,
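A minimal, self-contained sketch of what the new block does to the llama.cpp argument list; the base arguments and model path below are placeholder assumptions, not the real output of l.config.GetArgs. A non-negative ContextSize is forwarded as --ctx-size, and RawFlags are appended verbatim at the end:

package main

import (
	"fmt"
	"strconv"
)

// BackendConfiguration mirrors the type added in pkg/inference/backend.go.
type BackendConfiguration struct {
	ContextSize int64
	RawFlags    []string
}

func main() {
	// Placeholder base arguments standing in for l.config.GetArgs(...).
	args := []string{"--model", "/models/example.gguf", "--host", "unix:///tmp/runner.sock"}
	config := &BackendConfiguration{
		ContextSize: 8192,
		RawFlags:    []string{"--temp", "0.2"},
	}
	if config != nil {
		if config.ContextSize >= 0 {
			args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize)))
		}
		args = append(args, config.RawFlags...)
	}
	fmt.Println(args)
	// [--model /models/example.gguf --host unix:///tmp/runner.sock --ctx-size 8192 --temp 0.2]
}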

pkg/inference/backends/mlx/mlx.go
Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error {
 }

 // Run implements inference.Backend.Run.
-func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
+func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
 	// TODO: Implement.
 	m.log.Warn("MLX backend is not yet supported")
 	return errors.New("not implemented")

pkg/inference/backends/vllm/vllm.go
Lines changed: 1 addition & 1 deletion

@@ -49,7 +49,7 @@ func (v *vLLM) Install(ctx context.Context, httpClient *http.Client) error {
 }

 // Run implements inference.Backend.Run.
-func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
+func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
 	// TODO: Implement.
 	v.log.Warn("vLLM backend is not yet supported")
 	return errors.New("not implemented")

pkg/inference/scheduling/api.go
Lines changed: 7 additions & 0 deletions

@@ -73,3 +73,10 @@ type UnloadRequest struct {
 type UnloadResponse struct {
 	UnloadedRunners int `json:"unloaded_runners"`
 }
+
+// ConfigureRequest specifies per-model runtime configuration options.
+type ConfigureRequest struct {
+	Model           string `json:"model"`
+	ContextSize     int64  `json:"context-size,omitempty"`
+	RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
+}
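For context, a hedged client-side sketch of how this request type might be used against the new endpoint. The base URL, model name, and flag values are illustrative assumptions, not values taken from this commit; the route shape ({prefix}/{backend}/_configure) and the 202/409 status codes come from the scheduler.go changes further down.

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// ConfigureRequest mirrors the struct added above.
type ConfigureRequest struct {
	Model           string `json:"model"`
	ContextSize     int64  `json:"context-size,omitempty"`
	RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
}

func main() {
	payload, err := json.Marshal(ConfigureRequest{
		Model:           "ai/example-model",      // hypothetical model reference
		ContextSize:     8192,                    // forwarded to llama.cpp as --ctx-size
		RawRuntimeFlags: "--temp 0.2 --top-k 40", // tokenised with go-shellwords and appended as-is
	})
	if err != nil {
		panic(err)
	}
	// Assumed base URL; the real prefix is inference.InferencePrefix, and the
	// backend path segment is optional (the default backend is used when omitted).
	resp, err := http.Post(
		"http://localhost:12434/engines/llama.cpp/_configure",
		"application/json",
		bytes.NewReader(payload),
	)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	// 202 Accepted on success; 409 Conflict if a runner for the model is already active.
	fmt.Println(resp.Status)
}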

pkg/inference/scheduling/loader.go
Lines changed: 28 additions & 1 deletion

@@ -29,6 +29,9 @@ var (
 	// errModelTooBig indicates that the model is too big to ever load into the
 	// available system memory.
 	errModelTooBig = errors.New("model too big")
+	// errRunnerAlreadyActive indicates that a given runner is already active
+	// and therefore can't be reconfigured for example
+	errRunnerAlreadyActive = errors.New("runner already active")
 )

 // runnerKey is used to index runners.
@@ -82,6 +85,8 @@ type loader struct {
 	// timestamps maps slot indices to last usage times. Values in this slice
 	// are only valid if the corresponding reference count is zero.
 	timestamps []time.Time
+	// runnerConfigs maps model names to runner configurations
+	runnerConfigs map[runnerKey]inference.BackendConfiguration
 }

 // newLoader creates a new loader.
@@ -122,6 +127,7 @@ func newLoader(
 		references:  make([]uint, nSlots),
 		allocations: make([]uint64, nSlots),
 		timestamps:  make([]time.Time, nSlots),
+		runnerConfigs: make(map[runnerKey]inference.BackendConfiguration),
 	}
 	l.guard <- struct{}{}
 	return l
@@ -214,9 +220,11 @@ func (l *loader) Unload(ctx context.Context, unload UnloadRequest) int {

 	return len(l.runners) - func() int {
 		if unload.All {
+			l.runnerConfigs = make(map[runnerKey]inference.BackendConfiguration)
 			return l.evict(false)
 		} else {
 			for _, model := range unload.Models {
+				delete(l.runnerConfigs, runnerKey{unload.Backend, model, inference.BackendModeCompletion})
 				// Evict both, completion and embedding models. We should consider
 				// accepting a mode parameter in unload requests.
 				l.evictRunner(unload.Backend, model, inference.BackendModeCompletion)
@@ -413,9 +421,13 @@ func (l *loader) load(ctx context.Context, backendName, model string, mode infer

 	// If we've identified a slot, then we're ready to start a runner.
 	if slot >= 0 {
+		var runnerConfig *inference.BackendConfiguration
+		if rc, ok := l.runnerConfigs[runnerKey{backendName, model, mode}]; ok {
+			runnerConfig = &rc
+		}
 		// Create the runner.
 		l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, model, mode)
-		runner, err := run(l.log, backend, model, mode, slot)
+		runner, err := run(l.log, backend, model, mode, slot, runnerConfig)
 		if err != nil {
 			l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v",
 				backendName, model, mode, err,
@@ -492,3 +504,18 @@ func (l *loader) release(runner *runner) {
 	// Signal waiters.
 	l.broadcast()
 }
+
+func (l *loader) setRunnerConfig(ctx context.Context, backendName, model string, mode inference.BackendMode, runnerConfig inference.BackendConfiguration) error {
+	l.lock(ctx)
+	defer l.unlock()
+
+	runnerId := runnerKey{backendName, model, mode}
+
+	if _, ok := l.runners[runnerId]; ok {
+		return errRunnerAlreadyActive
+	}
+
+	l.log.Infof("Configuring %s runner for %s", backendName, model)
+	l.runnerConfigs[runnerId] = runnerConfig
+	return nil
+}
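A trimmed-down sketch of the guard that setRunnerConfig enforces: a configuration is only accepted while no runner is active for the same (backend, model, mode) key, which is what surfaces as 409 Conflict in the HTTP handler. The types below are simplified stand-ins for the real loader state, not the actual implementation.

package main

import (
	"errors"
	"fmt"
)

type runnerKey struct {
	backend, model, mode string
}

type BackendConfiguration struct {
	ContextSize int64
	RawFlags    []string
}

var errRunnerAlreadyActive = errors.New("runner already active")

type loader struct {
	runners       map[runnerKey]struct{}
	runnerConfigs map[runnerKey]BackendConfiguration
}

func (l *loader) setRunnerConfig(key runnerKey, cfg BackendConfiguration) error {
	if _, active := l.runners[key]; active {
		// Reconfiguring a live runner is rejected; the scheduler maps this to 409.
		return errRunnerAlreadyActive
	}
	l.runnerConfigs[key] = cfg
	return nil
}

func main() {
	l := &loader{
		runners:       map[runnerKey]struct{}{},
		runnerConfigs: map[runnerKey]BackendConfiguration{},
	}
	key := runnerKey{"llama.cpp", "ai/example-model", "completion"}
	fmt.Println(l.setRunnerConfig(key, BackendConfiguration{ContextSize: 4096})) // <nil>
	l.runners[key] = struct{}{}                                                  // simulate the model being loaded
	fmt.Println(l.setRunnerConfig(key, BackendConfiguration{ContextSize: 8192})) // runner already active
}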

pkg/inference/scheduling/runner.go
Lines changed: 2 additions & 1 deletion

@@ -73,6 +73,7 @@ func run(
 	model string,
 	mode inference.BackendMode,
 	slot int,
+	runnerConfig *inference.BackendConfiguration,
 ) (*runner, error) {
 	// Create a dialer / transport that target backend on the specified slot.
 	socket, err := RunnerSocketPath(slot)
@@ -152,7 +153,7 @@

 	// Start the backend run loop.
 	go func() {
-		if err := backend.Run(runCtx, socket, model, mode); err != nil {
+		if err := backend.Run(runCtx, socket, model, mode, runnerConfig); err != nil {
 			log.Warnf("Backend %s running model %s exited with error: %v",
 				backend.Name(), model, err,
 			)

pkg/inference/scheduling/scheduler.go
Lines changed: 58 additions & 0 deletions

@@ -17,6 +17,7 @@ import (
 	"github.com/docker/model-runner/pkg/inference/models"
 	"github.com/docker/model-runner/pkg/logging"
 	"github.com/docker/model-runner/pkg/metrics"
+	"github.com/mattn/go-shellwords"
 	"golang.org/x/sync/errgroup"
 )

@@ -112,6 +113,8 @@ func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.Handl
 	m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
 	m["GET "+inference.InferencePrefix+"/df"] = s.GetDiskUsage
 	m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
+	m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = s.Configure
+	m["POST "+inference.InferencePrefix+"/_configure"] = s.Configure
 	return m
 }

@@ -347,6 +350,61 @@ func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) {
 	}
 }

+func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) {
+	// Determine the requested backend and ensure that it's valid.
+	var backend inference.Backend
+	if b := r.PathValue("backend"); b == "" {
+		backend = s.defaultBackend
+	} else {
+		backend = s.backends[b]
+	}
+	if backend == nil {
+		http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
+		return
+	}
+
+	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
+	if err != nil {
+		if _, ok := err.(*http.MaxBytesError); ok {
+			http.Error(w, "request too large", http.StatusBadRequest)
+		} else {
+			http.Error(w, "unknown error", http.StatusInternalServerError)
+		}
+		return
+	}
+
+	configureRequest := ConfigureRequest{
+		Model:           "",
+		ContextSize:     -1,
+		RawRuntimeFlags: "",
+	}
+	if err := json.Unmarshal(body, &configureRequest); err != nil {
+		http.Error(w, "invalid request", http.StatusBadRequest)
+		return
+	}
+	rawFlags, err := shellwords.Parse(configureRequest.RawRuntimeFlags)
+	if err != nil {
+		http.Error(w, "invalid request", http.StatusBadRequest)
+		return
+	}
+
+	var runnerConfig inference.BackendConfiguration
+	runnerConfig.ContextSize = configureRequest.ContextSize
+	runnerConfig.RawFlags = rawFlags
+
+	if err := s.loader.setRunnerConfig(r.Context(), backend.Name(), configureRequest.Model, inference.BackendModeCompletion, runnerConfig); err != nil {
+		s.log.Warnf("Failed to configure %s runner for %s: %s", backend.Name(), configureRequest.Model, err)
+		if err == errRunnerAlreadyActive {
+			w.WriteHeader(http.StatusConflict)
+		} else {
+			w.WriteHeader(http.StatusInternalServerError)
+		}
+		return
+	}
+
+	w.WriteHeader(http.StatusAccepted)
+}
+
 // ServeHTTP implements net/http.Handler.ServeHTTP.
 func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	s.lock.Lock()
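Because the raw flags arrive as a single string, the handler tokenises them with github.com/mattn/go-shellwords before handing them to the backend. A small sketch of that behaviour; the flag values are only illustrative, and quoting is honoured, so a quoted value survives as a single argument:

package main

import (
	"fmt"

	shellwords "github.com/mattn/go-shellwords"
)

func main() {
	// Example flag string as it might appear in ConfigureRequest.RawRuntimeFlags.
	raw := `--temp 0.2 --override-kv "tokenizer.ggml.add_bos_token=bool:false"`

	flags, err := shellwords.Parse(raw)
	if err != nil {
		panic(err)
	}
	for _, f := range flags {
		fmt.Printf("%q\n", f)
	}
	// "--temp"
	// "0.2"
	// "--override-kv"
	// "tokenizer.ggml.add_bos_token=bool:false"
}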
