
Commit 5bcf3e5

Author: Piotr Stankiewicz (committed)

WiP: Support runner configuration
Signed-off-by: Piotr Stankiewicz <[email protected]>

Parent: e3916bc

8 files changed (+84, -6 lines)


pkg/inference/backend.go

6 additions, 1 deletion

```diff
@@ -29,6 +29,11 @@ func (m BackendMode) String() string {
 	}
 }
 
+type BackendConfiguration struct {
+	ContextSize uint64
+	RawFlags    string
+}
+
 // Backend is the interface implemented by inference engine backends. Backend
 // implementations need not be safe for concurrent invocation of the following
 // methods, though their underlying server implementations do need to support
@@ -66,7 +71,7 @@ type Backend interface {
 	// to be loaded. Backends should not load multiple models at once and should
 	// instead load only the specified model. Backends should still respond to
 	// OpenAI API requests for other models with a 421 error code.
-	Run(ctx context.Context, socket, model string, mode BackendMode) error
+	Run(ctx context.Context, socket, model string, mode BackendMode, config *BackendConfiguration) error
 	// Status returns a description of the backend's state.
 	Status() string
 	// GetDiskUsage returns the disk usage of the backend.
```

pkg/inference/backends/llamacpp/llamacpp.go

9 additions, 1 deletion

```diff
@@ -11,6 +11,7 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
+	"strconv"
 	"strings"
 
 	"github.com/docker/model-runner/pkg/diskusage"
@@ -120,7 +121,7 @@ func (l *llamaCpp) Install(ctx context.Context, httpClient *http.Client) error {
 }
 
 // Run implements inference.Backend.Run.
-func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
+func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
 	modelPath, err := l.modelManager.GetModelPath(model)
 	l.log.Infof("Model path: %s", modelPath)
 	if err != nil {
@@ -138,6 +139,13 @@ func (l *llamaCpp) Run(ctx context.Context, socket, model string, mode inference
 	}
 
 	args := l.config.GetArgs(modelPath, socket, mode)
+
+	if config != nil {
+		args = append(args, "--ctx-size", strconv.Itoa(int(config.ContextSize)))
+		// FIXME(p1-0tr): this needs to be parsed, to respect quoted values etc.
+		args = append(args, strings.Split(config.RawFlags, " ")...)
+	}
+
 	l.log.Infof("llamaCppArgs: %v", args)
 	llamaCppProcess := exec.CommandContext(
 		ctx,
```
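The FIXME above is real: `strings.Split(config.RawFlags, " ")` breaks any flag whose value contains spaces, quoted or not. A minimal quote-aware splitter might look like the sketch below; `splitRawFlags` is a hypothetical helper, not part of this commit, and a production version would also want escape handling and an error for unterminated quotes.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// splitRawFlags splits raw on whitespace while keeping single- or
// double-quoted segments intact (the quotes themselves are stripped).
// Hypothetical replacement for the strings.Split call above.
func splitRawFlags(raw string) []string {
	var (
		args    []string
		current strings.Builder
		quote   rune // active quote character, or 0 when unquoted
	)
	for _, r := range raw {
		switch {
		case quote != 0: // inside a quoted segment
			if r == quote {
				quote = 0 // closing quote
			} else {
				current.WriteRune(r)
			}
		case r == '\'' || r == '"': // opening quote
			quote = r
		case unicode.IsSpace(r): // token boundary
			if current.Len() > 0 {
				args = append(args, current.String())
				current.Reset()
			}
		default:
			current.WriteRune(r)
		}
	}
	if current.Len() > 0 {
		args = append(args, current.String())
	}
	return args
}

func main() {
	fmt.Println(splitRawFlags(`--temp 0.2 --chat-template "{{ .Prompt }}"`))
	// Output: [--temp 0.2 --chat-template {{ .Prompt }}]
}
```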

pkg/inference/backends/mlx/mlx.go

1 addition, 1 deletion

```diff
@@ -49,7 +49,7 @@ func (m *mlx) Install(ctx context.Context, httpClient *http.Client) error {
 }
 
 // Run implements inference.Backend.Run.
-func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
+func (m *mlx) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
 	// TODO: Implement.
 	m.log.Warn("MLX backend is not yet supported")
 	return errors.New("not implemented")
```

pkg/inference/backends/vllm/vllm.go

1 addition, 1 deletion

```diff
@@ -49,7 +49,7 @@ func (v *vLLM) Install(ctx context.Context, httpClient *http.Client) error {
 }
 
 // Run implements inference.Backend.Run.
-func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode) error {
+func (v *vLLM) Run(ctx context.Context, socket, model string, mode inference.BackendMode, config *inference.BackendConfiguration) error {
 	// TODO: Implement.
 	v.log.Warn("vLLM backend is not yet supported")
 	return errors.New("not implemented")
```

pkg/inference/scheduling/api.go

7 additions

```diff
@@ -73,3 +73,10 @@ type UnloadRequest struct {
 type UnloadResponse struct {
 	UnloadedRunners int `json:"unloaded_runners"`
 }
+
+// ConfigureRequest specifies per-model runtime configuration options.
+type ConfigureRequest struct {
+	Model           string `json:"model"`
+	ContextSize     uint64 `json:"context-size"`
+	RawRuntimeFlags string `json:"raw-runtime-flags"`
+}
```
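For reference, the JSON tags above are kebab-case, so a serialized request looks like the output of this small sketch (the model name is a made-up example):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors scheduling.ConfigureRequest from this commit.
type ConfigureRequest struct {
	Model           string `json:"model"`
	ContextSize     uint64 `json:"context-size"`
	RawRuntimeFlags string `json:"raw-runtime-flags"`
}

func main() {
	b, err := json.Marshal(ConfigureRequest{
		Model:           "ai/llama3.2", // hypothetical model reference
		ContextSize:     8192,
		RawRuntimeFlags: "--temp 0.2",
	})
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
	// {"model":"ai/llama3.2","context-size":8192,"raw-runtime-flags":"--temp 0.2"}
}
```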

pkg/inference/scheduling/loader.go

18 additions, 1 deletion

```diff
@@ -82,6 +82,8 @@ type loader struct {
 	// timestamps maps slot indices to last usage times. Values in this slice
 	// are only valid if the corresponding reference count is zero.
 	timestamps []time.Time
+	// runnerConfigs maps runner keys to their backend configurations.
+	runnerConfigs map[runnerKey]inference.BackendConfiguration
 }
 
 // newLoader creates a new loader.
@@ -122,6 +124,7 @@ func newLoader(
 		references:  make([]uint, nSlots),
 		allocations: make([]uint64, nSlots),
 		timestamps:  make([]time.Time, nSlots),
+		runnerConfigs: make(map[runnerKey]inference.BackendConfiguration),
 	}
 	l.guard <- struct{}{}
 	return l
@@ -214,9 +217,11 @@ func (l *loader) Unload(ctx context.Context, unload UnloadRequest) int {
 
 	return len(l.runners) - func() int {
 		if unload.All {
+			l.runnerConfigs = make(map[runnerKey]inference.BackendConfiguration)
 			return l.evict(false)
 		} else {
 			for _, model := range unload.Models {
+				delete(l.runnerConfigs, runnerKey{unload.Backend, model, inference.BackendModeCompletion})
 				// Evict both completion and embedding models. We should consider
 				// accepting a mode parameter in unload requests.
 				l.evictRunner(unload.Backend, model, inference.BackendModeCompletion)
@@ -413,9 +418,13 @@ func (l *loader) load(ctx context.Context, backendName, model string, mode infer
 
 	// If we've identified a slot, then we're ready to start a runner.
 	if slot >= 0 {
+		var runnerConfig *inference.BackendConfiguration
+		if rc, ok := l.runnerConfigs[runnerKey{backendName, model, mode}]; ok {
+			runnerConfig = &rc
+		}
 		// Create the runner.
 		l.log.Infof("Loading %s backend runner with model %s in %s mode", backendName, model, mode)
-		runner, err := run(l.log, backend, model, mode, slot)
+		runner, err := run(l.log, backend, model, mode, slot, runnerConfig)
 		if err != nil {
 			l.log.Warnf("Unable to start %s backend runner with model %s in %s mode: %v",
 				backendName, model, mode, err,
@@ -492,3 +501,11 @@ func (l *loader) release(runner *runner) {
 	// Signal waiters.
 	l.broadcast()
 }
+
+func (l *loader) setRunnerConfig(ctx context.Context, backendName, model string, mode inference.BackendMode, runnerConfig inference.BackendConfiguration) {
+	l.lock(ctx)
+	defer l.unlock()
+
+	l.log.Infof("Configuring %s runner for %s", backendName, model)
+	l.runnerConfigs[runnerKey{backendName, model, mode}] = runnerConfig
+}
```
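One subtlety in the `load` change above: Go does not allow taking the address of a map value, so the lookup copies the value into `rc` and passes `&rc`. The runner therefore holds a snapshot of the configuration; reconfiguring the map later does not mutate a runner that is already started. A tiny sketch of the idiom (names here are illustrative):

```go
package main

import "fmt"

type cfg struct{ ContextSize uint64 }

func main() {
	configs := map[string]cfg{"llama": {ContextSize: 4096}}

	// p := &configs["llama"] // compile error: cannot take the address of a map value

	if rc, ok := configs["llama"]; ok { // rc is a copy of the stored value
		p := &rc
		configs["llama"] = cfg{ContextSize: 8192} // does not affect *p
		fmt.Println(p.ContextSize)                // prints 4096
	}
}
```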

pkg/inference/scheduling/runner.go

2 additions, 1 deletion

```diff
@@ -73,6 +73,7 @@ func run(
 	model string,
 	mode inference.BackendMode,
 	slot int,
+	runnerConfig *inference.BackendConfiguration,
 ) (*runner, error) {
 	// Create a dialer / transport that target backend on the specified slot.
 	socket, err := RunnerSocketPath(slot)
@@ -152,7 +153,7 @@ func run(
 
 	// Start the backend run loop.
 	go func() {
-		if err := backend.Run(runCtx, socket, model, mode); err != nil {
+		if err := backend.Run(runCtx, socket, model, mode, runnerConfig); err != nil {
 			log.Warnf("Backend %s running model %s exited with error: %v",
 				backend.Name(), model, err,
 			)
```

pkg/inference/scheduling/scheduler.go

40 additions

```diff
@@ -112,6 +112,8 @@ func (s *Scheduler) routeHandlers(allowedOrigins []string) map[string]http.Handl
 	m["GET "+inference.InferencePrefix+"/ps"] = s.GetRunningBackends
 	m["GET "+inference.InferencePrefix+"/df"] = s.GetDiskUsage
 	m["POST "+inference.InferencePrefix+"/unload"] = s.Unload
+	m["POST "+inference.InferencePrefix+"/{backend}/configure"] = s.Configure
+	m["POST "+inference.InferencePrefix+"/configure"] = s.Configure
 	return m
 }
 
@@ -347,6 +349,44 @@ func (s *Scheduler) Unload(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+func (s *Scheduler) Configure(w http.ResponseWriter, r *http.Request) {
+	// Determine the requested backend and ensure that it's valid.
+	var backend inference.Backend
+	if b := r.PathValue("backend"); b == "" {
+		backend = s.defaultBackend
+	} else {
+		backend = s.backends[b]
+	}
+	if backend == nil {
+		http.Error(w, ErrBackendNotFound.Error(), http.StatusNotFound)
+		return
+	}
+
+	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maximumOpenAIInferenceRequestSize))
+	if err != nil {
+		if _, ok := err.(*http.MaxBytesError); ok {
+			http.Error(w, "request too large", http.StatusBadRequest)
+		} else {
+			http.Error(w, "unknown error", http.StatusInternalServerError)
+		}
+		return
+	}
+
+	var configureRequest ConfigureRequest
+	if err := json.Unmarshal(body, &configureRequest); err != nil {
+		http.Error(w, "invalid request", http.StatusBadRequest)
+		return
+	}
+
+	var runnerConfig inference.BackendConfiguration
+	runnerConfig.ContextSize = configureRequest.ContextSize
+	runnerConfig.RawFlags = configureRequest.RawRuntimeFlags
+
+	s.loader.setRunnerConfig(r.Context(), backend.Name(), configureRequest.Model, inference.BackendModeCompletion, runnerConfig)
+
+	w.WriteHeader(http.StatusOK)
+}
+
 // ServeHTTP implements net/http.Handler.ServeHTTP.
 func (s *Scheduler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	s.lock.Lock()
```