Skip to content

Commit 74d69c8

Browse files
authored
Merge pull request #520 from doringeman/list-configs
feat(scheduler): add endpoint to retrieve model configurations
2 parents a4ef5d7 + 4830663 commit 74d69c8

File tree

9 files changed

+172
-5
lines changed

9 files changed

+172
-5
lines changed

cmd/cli/commands/configure.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ func newConfigureCmd() *cobra.Command {
1111
var flags ConfigureFlags
1212

1313
c := &cobra.Command{
14-
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]",
15-
Short: "Configure runtime options for a model",
16-
Hidden: true,
14+
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]",
15+
Aliases: []string{"config"},
16+
Short: "Manage model runtime configurations",
17+
Hidden: true,
1718
Args: func(cmd *cobra.Command, args []string) error {
1819
argsBeforeDash := cmd.ArgsLenAtDash()
1920
if argsBeforeDash == -1 {
@@ -48,5 +49,6 @@ func newConfigureCmd() *cobra.Command {
4849
}
4950

5051
flags.RegisterFlags(c)
52+
c.AddCommand(newConfigureShowCmd())
5153
return c
5254
}

cmd/cli/commands/configure_show.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package commands
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
7+
"github.com/docker/model-runner/cmd/cli/commands/completion"
8+
"github.com/spf13/cobra"
9+
)
10+
11+
func newConfigureShowCmd() *cobra.Command {
12+
c := &cobra.Command{
13+
Use: "show [MODEL]",
14+
Short: "Show model configurations",
15+
Args: cobra.MaximumNArgs(1),
16+
RunE: func(cmd *cobra.Command, args []string) error {
17+
var modelFilter string
18+
if len(args) > 0 {
19+
modelFilter = args[0]
20+
}
21+
configs, err := desktopClient.ShowConfigs(modelFilter)
22+
if err != nil {
23+
return err
24+
}
25+
jsonResult, err := json.MarshalIndent(configs, "", " ")
26+
if err != nil {
27+
return fmt.Errorf("failed to marshal configs to JSON: %w", err)
28+
}
29+
cmd.Println(string(jsonResult))
30+
return nil
31+
},
32+
ValidArgsFunction: completion.ModelNames(getDesktopClient, 1),
33+
}
34+
return c
35+
}

cmd/cli/desktop/desktop.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,35 @@ func (c *Client) Unload(req UnloadRequest) (UnloadResponse, error) {
670670
return unloadResp, nil
671671
}
672672

673+
func (c *Client) ShowConfigs(modelFilter string) ([]scheduling.ModelConfigEntry, error) {
674+
configureBackendPath := inference.InferencePrefix + "/_configure"
675+
if modelFilter != "" {
676+
configureBackendPath += "?model=" + url.QueryEscape(modelFilter)
677+
}
678+
resp, err := c.doRequest(http.MethodGet, configureBackendPath, nil)
679+
if err != nil {
680+
return nil, c.handleQueryError(err, configureBackendPath)
681+
}
682+
defer resp.Body.Close()
683+
684+
if resp.StatusCode != http.StatusOK {
685+
body, _ := io.ReadAll(resp.Body)
686+
return nil, fmt.Errorf("listing configs failed with status %s: %s", resp.Status, string(body))
687+
}
688+
689+
body, err := io.ReadAll(resp.Body)
690+
if err != nil {
691+
return nil, fmt.Errorf("failed to read response body: %w", err)
692+
}
693+
694+
var configs []scheduling.ModelConfigEntry
695+
if err := json.Unmarshal(body, &configs); err != nil {
696+
return nil, fmt.Errorf("failed to unmarshal response body: %w", err)
697+
}
698+
699+
return configs, nil
700+
}
701+
673702
func (c *Client) ConfigureBackend(request scheduling.ConfigureRequest) error {
674703
configureBackendPath := inference.InferencePrefix + "/_configure"
675704
jsonData, err := json.Marshal(request)

cmd/cli/docs/reference/docker_model_configure.yaml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
command: docker model configure
2-
short: Configure runtime options for a model
3-
long: Configure runtime options for a model
2+
aliases: docker model configure, docker model config
3+
short: Manage model runtime configurations
4+
long: Manage model runtime configurations
45
usage: docker model configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]
56
pname: docker model
67
plink: docker_model.yaml
8+
cname:
9+
- docker model configure show
10+
clink:
11+
- docker_model_configure_show.yaml
712
options:
813
- option: context-size
914
value_type: int32
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
command: docker model configure show
2+
short: Show model configurations
3+
long: Show model configurations
4+
usage: docker model configure show [MODEL]
5+
pname: docker model configure
6+
plink: docker_model_configure.yaml
7+
deprecated: false
8+
hidden: true
9+
experimental: false
10+
experimentalcli: false
11+
kubernetes: false
12+
swarm: false
13+

pkg/inference/backend.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package inference
22

33
import (
44
"context"
5+
"encoding/json"
6+
"fmt"
57
"net/http"
68
)
79

@@ -40,6 +42,25 @@ func (m BackendMode) String() string {
4042
}
4143
}
4244

45+
// MarshalJSON implements json.Marshaler for BackendMode.
46+
func (m BackendMode) MarshalJSON() ([]byte, error) {
47+
return []byte(`"` + m.String() + `"`), nil
48+
}
49+
50+
// UnmarshalJSON implements json.Unmarshaler for BackendMode.
51+
func (m *BackendMode) UnmarshalJSON(data []byte) error {
52+
var s string
53+
if err := json.Unmarshal(data, &s); err != nil {
54+
return err
55+
}
56+
mode, ok := ParseBackendMode(s)
57+
if !ok {
58+
return fmt.Errorf("unknown backend mode: %q", s)
59+
}
60+
*m = mode
61+
return nil
62+
}
63+
4364
// ParseBackendMode converts a string mode to BackendMode.
4465
// It returns the parsed mode and a boolean indicating if the mode was known.
4566
// For unknown modes, it returns BackendModeCompletion and false.

pkg/inference/scheduling/api.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,12 @@ type ConfigureRequest struct {
9898
RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
9999
inference.BackendConfiguration
100100
}
101+
102+
// ModelConfigEntry represents a model configuration entry with its associated metadata.
103+
type ModelConfigEntry struct {
104+
Backend string
105+
Model string
106+
ModelID string
107+
Mode inference.BackendMode
108+
Config inference.BackendConfiguration
109+
}

pkg/inference/scheduling/http_handler.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc {
8484
m["POST "+inference.InferencePrefix+"/unload"] = h.Unload
8585
m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = h.Configure
8686
m["POST "+inference.InferencePrefix+"/_configure"] = h.Configure
87+
m["GET "+inference.InferencePrefix+"/_configure"] = h.GetModelConfigs
8788
m["GET "+inference.InferencePrefix+"/requests"] = h.scheduler.openAIRecorder.GetRecordsHandler()
8889
return m
8990
}
@@ -350,6 +351,31 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
350351
w.WriteHeader(http.StatusAccepted)
351352
}
352353

354+
// GetModelConfigs returns model configurations. If a model is specified in the query parameter,
355+
// returns only configs for that model; otherwise returns all configs.
356+
func (h *HTTPHandler) GetModelConfigs(w http.ResponseWriter, r *http.Request) {
357+
model := r.URL.Query().Get("model")
358+
359+
configs := h.scheduler.loader.getAllRunnerConfigs(r.Context())
360+
361+
if model != "" {
362+
modelID := h.scheduler.modelManager.ResolveID(model)
363+
filtered := configs[:0]
364+
for _, entry := range configs {
365+
if entry.ModelID == modelID {
366+
filtered = append(filtered, entry)
367+
}
368+
}
369+
configs = filtered
370+
}
371+
372+
w.Header().Set("Content-Type", "application/json")
373+
if err := json.NewEncoder(w).Encode(configs); err != nil {
374+
http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
375+
return
376+
}
377+
}
378+
353379
// ServeHTTP implements net/http.Handler.ServeHTTP.
354380
func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
355381
h.lock.RLock()

pkg/inference/scheduling/loader.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,3 +626,30 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID strin
626626
l.runnerConfigs[configKey] = runnerConfig
627627
return nil
628628
}
629+
630+
// getAllRunnerConfigs retrieves all runner configurations.
631+
func (l *loader) getAllRunnerConfigs(ctx context.Context) []ModelConfigEntry {
632+
if !l.lock(ctx) {
633+
return nil
634+
}
635+
defer l.unlock()
636+
637+
entries := make([]ModelConfigEntry, 0, len(l.runnerConfigs))
638+
for key, config := range l.runnerConfigs {
639+
model, err := l.modelManager.GetLocal(key.modelID)
640+
if err == nil {
641+
modelName := ""
642+
if len(model.Tags()) > 0 {
643+
modelName = model.Tags()[0]
644+
}
645+
entries = append(entries, ModelConfigEntry{
646+
Backend: key.backend,
647+
Model: modelName,
648+
ModelID: key.modelID,
649+
Mode: key.mode,
650+
Config: config,
651+
})
652+
}
653+
}
654+
return entries
655+
}

0 commit comments

Comments
 (0)