diff --git a/cmd/cli/commands/configure.go b/cmd/cli/commands/configure.go index 9679e3b0..22f9e52c 100644 --- a/cmd/cli/commands/configure.go +++ b/cmd/cli/commands/configure.go @@ -11,9 +11,10 @@ func newConfigureCmd() *cobra.Command { var flags ConfigureFlags c := &cobra.Command{ - Use: "configure [--context-size=] [--speculative-draft-model=] [--hf_overrides=] [--gpu-memory-utilization=] [--mode=] [--think] MODEL [-- ]", - Short: "Configure runtime options for a model", - Hidden: true, + Use: "configure [--context-size=] [--speculative-draft-model=] [--hf_overrides=] [--gpu-memory-utilization=] [--mode=] [--think] MODEL [-- ]", + Aliases: []string{"config"}, + Short: "Manage model runtime configurations", + Hidden: true, Args: func(cmd *cobra.Command, args []string) error { argsBeforeDash := cmd.ArgsLenAtDash() if argsBeforeDash == -1 { @@ -48,5 +49,6 @@ func newConfigureCmd() *cobra.Command { } flags.RegisterFlags(c) + c.AddCommand(newConfigureShowCmd()) return c } diff --git a/cmd/cli/commands/configure_show.go b/cmd/cli/commands/configure_show.go new file mode 100644 index 00000000..15864c89 --- /dev/null +++ b/cmd/cli/commands/configure_show.go @@ -0,0 +1,35 @@ +package commands + +import ( + "encoding/json" + "fmt" + + "github.com/docker/model-runner/cmd/cli/commands/completion" + "github.com/spf13/cobra" +) + +func newConfigureShowCmd() *cobra.Command { + c := &cobra.Command{ + Use: "show [MODEL]", + Short: "Show model configurations", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + var modelFilter string + if len(args) > 0 { + modelFilter = args[0] + } + configs, err := desktopClient.ShowConfigs(modelFilter) + if err != nil { + return err + } + jsonResult, err := json.MarshalIndent(configs, "", " ") + if err != nil { + return fmt.Errorf("failed to marshal configs to JSON: %w", err) + } + cmd.Println(string(jsonResult)) + return nil + }, + ValidArgsFunction: completion.ModelNames(getDesktopClient, 1), + } + return c +} diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 05e1d330..11facc42 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -670,6 +670,35 @@ func (c *Client) Unload(req UnloadRequest) (UnloadResponse, error) { return unloadResp, nil } +func (c *Client) ShowConfigs(modelFilter string) ([]scheduling.ModelConfigEntry, error) { + configureBackendPath := inference.InferencePrefix + "/_configure" + if modelFilter != "" { + configureBackendPath += "?model=" + url.QueryEscape(modelFilter) + } + resp, err := c.doRequest(http.MethodGet, configureBackendPath, nil) + if err != nil { + return nil, c.handleQueryError(err, configureBackendPath) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("listing configs failed with status %s: %s", resp.Status, string(body)) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + var configs []scheduling.ModelConfigEntry + if err := json.Unmarshal(body, &configs); err != nil { + return nil, fmt.Errorf("failed to unmarshal response body: %w", err) + } + + return configs, nil +} + func (c *Client) ConfigureBackend(request scheduling.ConfigureRequest) error { configureBackendPath := inference.InferencePrefix + "/_configure" jsonData, err := json.Marshal(request) diff --git a/cmd/cli/docs/reference/docker_model_configure.yaml b/cmd/cli/docs/reference/docker_model_configure.yaml index 9a9d3e8c..ce7ac015 100644 --- a/cmd/cli/docs/reference/docker_model_configure.yaml +++ b/cmd/cli/docs/reference/docker_model_configure.yaml @@ -1,9 +1,14 @@ command: docker model configure -short: Configure runtime options for a model -long: Configure runtime options for a model +aliases: docker model configure, docker model config +short: Manage model runtime configurations +long: Manage model runtime configurations usage: docker model configure [--context-size=] [--speculative-draft-model=] [--hf_overrides=] [--gpu-memory-utilization=] [--mode=] [--think] MODEL [-- ] pname: docker model plink: docker_model.yaml +cname: + - docker model configure show +clink: + - docker_model_configure_show.yaml options: - option: context-size value_type: int32 diff --git a/cmd/cli/docs/reference/docker_model_configure_show.yaml b/cmd/cli/docs/reference/docker_model_configure_show.yaml new file mode 100644 index 00000000..588c5c1b --- /dev/null +++ b/cmd/cli/docs/reference/docker_model_configure_show.yaml @@ -0,0 +1,13 @@ +command: docker model configure show +short: Show model configurations +long: Show model configurations +usage: docker model configure show [MODEL] +pname: docker model configure +plink: docker_model_configure.yaml +deprecated: false +hidden: true +experimental: false +experimentalcli: false +kubernetes: false +swarm: false + diff --git a/pkg/inference/backend.go b/pkg/inference/backend.go index 2ea64d87..36b7580a 100644 --- a/pkg/inference/backend.go +++ b/pkg/inference/backend.go @@ -2,6 +2,8 @@ package inference import ( "context" + "encoding/json" + "fmt" "net/http" ) @@ -40,6 +42,25 @@ func (m BackendMode) String() string { } } +// MarshalJSON implements json.Marshaler for BackendMode. +func (m BackendMode) MarshalJSON() ([]byte, error) { + return []byte(`"` + m.String() + `"`), nil +} + +// UnmarshalJSON implements json.Unmarshaler for BackendMode. +func (m *BackendMode) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return err + } + mode, ok := ParseBackendMode(s) + if !ok { + return fmt.Errorf("unknown backend mode: %q", s) + } + *m = mode + return nil +} + // ParseBackendMode converts a string mode to BackendMode. // It returns the parsed mode and a boolean indicating if the mode was known. // For unknown modes, it returns BackendModeCompletion and false. diff --git a/pkg/inference/scheduling/api.go b/pkg/inference/scheduling/api.go index e899daa4..e66a28ff 100644 --- a/pkg/inference/scheduling/api.go +++ b/pkg/inference/scheduling/api.go @@ -98,3 +98,12 @@ type ConfigureRequest struct { RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"` inference.BackendConfiguration } + +// ModelConfigEntry represents a model configuration entry with its associated metadata. +type ModelConfigEntry struct { + Backend string + Model string + ModelID string + Mode inference.BackendMode + Config inference.BackendConfiguration +} diff --git a/pkg/inference/scheduling/http_handler.go b/pkg/inference/scheduling/http_handler.go index 67d5cd50..4d2ccb43 100644 --- a/pkg/inference/scheduling/http_handler.go +++ b/pkg/inference/scheduling/http_handler.go @@ -84,6 +84,7 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc { m["POST "+inference.InferencePrefix+"/unload"] = h.Unload m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = h.Configure m["POST "+inference.InferencePrefix+"/_configure"] = h.Configure + m["GET "+inference.InferencePrefix+"/_configure"] = h.GetModelConfigs m["GET "+inference.InferencePrefix+"/requests"] = h.scheduler.openAIRecorder.GetRecordsHandler() return m } @@ -350,6 +351,31 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusAccepted) } +// GetModelConfigs returns model configurations. If a model is specified in the query parameter, +// returns only configs for that model; otherwise returns all configs. +func (h *HTTPHandler) GetModelConfigs(w http.ResponseWriter, r *http.Request) { + model := r.URL.Query().Get("model") + + configs := h.scheduler.loader.getAllRunnerConfigs(r.Context()) + + if model != "" { + modelID := h.scheduler.modelManager.ResolveID(model) + filtered := configs[:0] + for _, entry := range configs { + if entry.ModelID == modelID { + filtered = append(filtered, entry) + } + } + configs = filtered + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(configs); err != nil { + http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError) + return + } +} + // ServeHTTP implements net/http.Handler.ServeHTTP. func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.lock.RLock() diff --git a/pkg/inference/scheduling/loader.go b/pkg/inference/scheduling/loader.go index 3d6c4d3b..a4762353 100644 --- a/pkg/inference/scheduling/loader.go +++ b/pkg/inference/scheduling/loader.go @@ -626,3 +626,30 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID strin l.runnerConfigs[configKey] = runnerConfig return nil } + +// getAllRunnerConfigs retrieves all runner configurations. +func (l *loader) getAllRunnerConfigs(ctx context.Context) []ModelConfigEntry { + if !l.lock(ctx) { + return nil + } + defer l.unlock() + + entries := make([]ModelConfigEntry, 0, len(l.runnerConfigs)) + for key, config := range l.runnerConfigs { + model, err := l.modelManager.GetLocal(key.modelID) + if err == nil { + modelName := "" + if len(model.Tags()) > 0 { + modelName = model.Tags()[0] + } + entries = append(entries, ModelConfigEntry{ + Backend: key.backend, + Model: modelName, + ModelID: key.modelID, + Mode: key.mode, + Config: config, + }) + } + } + return entries +}