Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions cmd/cli/commands/configure.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ func newConfigureCmd() *cobra.Command {
var flags ConfigureFlags

c := &cobra.Command{
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]",
Short: "Configure runtime options for a model",
Hidden: true,
Use: "configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]",
Aliases: []string{"config"},
Short: "Manage model runtime configurations",
Hidden: true,
Args: func(cmd *cobra.Command, args []string) error {
argsBeforeDash := cmd.ArgsLenAtDash()
if argsBeforeDash == -1 {
Expand Down Expand Up @@ -48,5 +49,6 @@ func newConfigureCmd() *cobra.Command {
}

flags.RegisterFlags(c)
c.AddCommand(newConfigureShowCmd())
return c
}
35 changes: 35 additions & 0 deletions cmd/cli/commands/configure_show.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package commands

import (
"encoding/json"
"fmt"

"github.com/docker/model-runner/cmd/cli/commands/completion"
"github.com/spf13/cobra"
)

// newConfigureShowCmd builds the "show" subcommand. It asks the desktop client
// for model configurations, optionally restricted to the single MODEL argument,
// and prints the result as indented JSON on the command's output.
func newConfigureShowCmd() *cobra.Command {
	run := func(cmd *cobra.Command, args []string) error {
		// With no positional argument the filter stays empty.
		filter := ""
		if len(args) != 0 {
			filter = args[0]
		}

		entries, err := desktopClient.ShowConfigs(filter)
		if err != nil {
			return err
		}

		rendered, err := json.MarshalIndent(entries, "", " ")
		if err != nil {
			return fmt.Errorf("failed to marshal configs to JSON: %w", err)
		}

		cmd.Println(string(rendered))
		return nil
	}

	return &cobra.Command{
		Use:               "show [MODEL]",
		Short:             "Show model configurations",
		Args:              cobra.MaximumNArgs(1),
		RunE:              run,
		ValidArgsFunction: completion.ModelNames(getDesktopClient, 1),
	}
}
29 changes: 29 additions & 0 deletions cmd/cli/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,35 @@ func (c *Client) Unload(req UnloadRequest) (UnloadResponse, error) {
return unloadResp, nil
}

// ShowConfigs issues a GET against the "_configure" endpoint and decodes the
// JSON response into a slice of configuration entries.
//
// When modelFilter is non-empty it is query-escaped and sent as the "model"
// query parameter. A non-200 response is reported as an error that carries the
// status line and whatever part of the body could be read.
func (c *Client) ShowConfigs(modelFilter string) ([]scheduling.ModelConfigEntry, error) {
	endpoint := inference.InferencePrefix + "/_configure"
	if modelFilter != "" {
		endpoint = endpoint + "?model=" + url.QueryEscape(modelFilter)
	}

	resp, err := c.doRequest(http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, c.handleQueryError(err, endpoint)
	}
	defer resp.Body.Close()

	// The body is needed on both the success and the failure path, so read it once.
	payload, readErr := io.ReadAll(resp.Body)
	switch {
	case resp.StatusCode != http.StatusOK:
		// A read error is deliberately not surfaced here: the status is the
		// primary failure and the partial body is only extra detail.
		return nil, fmt.Errorf("listing configs failed with status %s: %s", resp.Status, string(payload))
	case readErr != nil:
		return nil, fmt.Errorf("failed to read response body: %w", readErr)
	}

	var entries []scheduling.ModelConfigEntry
	if decodeErr := json.Unmarshal(payload, &entries); decodeErr != nil {
		return nil, fmt.Errorf("failed to unmarshal response body: %w", decodeErr)
	}
	return entries, nil
}

func (c *Client) ConfigureBackend(request scheduling.ConfigureRequest) error {
configureBackendPath := inference.InferencePrefix + "/_configure"
jsonData, err := json.Marshal(request)
Expand Down
9 changes: 7 additions & 2 deletions cmd/cli/docs/reference/docker_model_configure.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
command: docker model configure
short: Configure runtime options for a model
long: Configure runtime options for a model
aliases: docker model configure, docker model config
short: Manage model runtime configurations
long: Manage model runtime configurations
usage: docker model configure [--context-size=<n>] [--speculative-draft-model=<model>] [--hf_overrides=<json>] [--gpu-memory-utilization=<float>] [--mode=<mode>] [--think] MODEL [-- <runtime-flags...>]
pname: docker model
plink: docker_model.yaml
cname:
- docker model configure show
clink:
- docker_model_configure_show.yaml
options:
- option: context-size
value_type: int32
Expand Down
13 changes: 13 additions & 0 deletions cmd/cli/docs/reference/docker_model_configure_show.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
command: docker model configure show
short: Show model configurations
long: Show model configurations
usage: docker model configure show [MODEL]
pname: docker model configure
plink: docker_model_configure.yaml
deprecated: false
hidden: true
experimental: false
experimentalcli: false
kubernetes: false
swarm: false

21 changes: 21 additions & 0 deletions pkg/inference/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package inference

import (
"context"
"encoding/json"
"fmt"
"net/http"
)

Expand Down Expand Up @@ -40,6 +42,25 @@ func (m BackendMode) String() string {
}
}

// MarshalJSON implements json.Marshaler for BackendMode.
//
// The mode is encoded as a JSON string holding the value of m.String().
// Encoding is delegated to json.Marshal rather than wrapping the text in
// hand-placed quotes, so the output remains valid JSON even if String ever
// returns a quote, a backslash, or a control character.
func (m BackendMode) MarshalJSON() ([]byte, error) {
	return json.Marshal(m.String())
}

// UnmarshalJSON implements json.Unmarshaler for BackendMode.
//
// The input must be a JSON string naming a mode that ParseBackendMode
// recognizes. Any other input yields an error and leaves *m unchanged.
func (m *BackendMode) UnmarshalJSON(data []byte) error {
	var raw string
	err := json.Unmarshal(data, &raw)
	if err != nil {
		return err
	}

	if parsed, known := ParseBackendMode(raw); known {
		*m = parsed
		return nil
	}
	return fmt.Errorf("unknown backend mode: %q", raw)
}

// ParseBackendMode converts a string mode to BackendMode.
// It returns the parsed mode and a boolean indicating if the mode was known.
// For unknown modes, it returns BackendModeCompletion and false.
Expand Down
9 changes: 9 additions & 0 deletions pkg/inference/scheduling/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,12 @@ type ConfigureRequest struct {
RawRuntimeFlags string `json:"raw-runtime-flags,omitempty"`
inference.BackendConfiguration
}

// ModelConfigEntry represents a model configuration entry with its associated metadata.
//
// The fields carry no JSON tags, so encoding/json uses the Go field names
// ("Backend", "Model", ...) as object keys when an entry is encoded or decoded.
type ModelConfigEntry struct {
	// Backend is the name of the backend the configuration is stored under.
	Backend string
	// Model is a human-readable reference for the model. The loader fills it
	// with the model's first tag and leaves it empty when the model has no tags.
	Model string
	// ModelID is the ID of the model the configuration is stored under.
	ModelID string
	// Mode is the backend mode the configuration is stored under.
	Mode inference.BackendMode
	// Config is the stored backend configuration itself.
	Config inference.BackendConfiguration
}
26 changes: 26 additions & 0 deletions pkg/inference/scheduling/http_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func (h *HTTPHandler) routeHandlers() map[string]http.HandlerFunc {
m["POST "+inference.InferencePrefix+"/unload"] = h.Unload
m["POST "+inference.InferencePrefix+"/{backend}/_configure"] = h.Configure
m["POST "+inference.InferencePrefix+"/_configure"] = h.Configure
m["GET "+inference.InferencePrefix+"/_configure"] = h.GetModelConfigs
m["GET "+inference.InferencePrefix+"/requests"] = h.scheduler.openAIRecorder.GetRecordsHandler()
return m
}
Expand Down Expand Up @@ -350,6 +351,31 @@ func (h *HTTPHandler) Configure(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusAccepted)
}

// GetModelConfigs responds with the known model configurations encoded as JSON.
// When the request carries a non-empty "model" query parameter, that reference is
// resolved to a model ID and only entries with a matching ModelID are returned;
// without it, every entry is returned.
func (h *HTTPHandler) GetModelConfigs(w http.ResponseWriter, r *http.Request) {
	entries := h.scheduler.loader.getAllRunnerConfigs(r.Context())

	if requested := r.URL.Query().Get("model"); requested != "" {
		wantID := h.scheduler.modelManager.ResolveID(requested)

		// Compact the matching entries to the front of the slice and cut off the rest.
		kept := 0
		for i := range entries {
			if entries[i].ModelID != wantID {
				continue
			}
			entries[kept] = entries[i]
			kept++
		}
		entries = entries[:kept]
	}

	w.Header().Set("Content-Type", "application/json")
	encoder := json.NewEncoder(w)
	if err := encoder.Encode(entries); err != nil {
		http.Error(w, fmt.Sprintf("Failed to encode response: %v", err), http.StatusInternalServerError)
	}
}

// ServeHTTP implements net/http.Handler.ServeHTTP.
func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.lock.RLock()
Expand Down
27 changes: 27 additions & 0 deletions pkg/inference/scheduling/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,3 +626,30 @@ func (l *loader) setRunnerConfig(ctx context.Context, backendName, modelID strin
l.runnerConfigs[configKey] = runnerConfig
return nil
}

// getAllRunnerConfigs builds one ModelConfigEntry per stored runner configuration
// whose model can be looked up through the model manager.
//
// It returns nil when the loader lock cannot be acquired. Configurations whose
// model lookup fails are left out of the result. Entries follow map iteration
// order, so their order is not stable between calls.
func (l *loader) getAllRunnerConfigs(ctx context.Context) []ModelConfigEntry {
	if acquired := l.lock(ctx); !acquired {
		return nil
	}
	defer l.unlock()

	result := make([]ModelConfigEntry, 0, len(l.runnerConfigs))
	for k, cfg := range l.runnerConfigs {
		m, err := l.modelManager.GetLocal(k.modelID)
		if err != nil {
			// No local model for this configuration: skip it.
			continue
		}

		// The first tag, when there is one, serves as the display name.
		var name string
		if tags := m.Tags(); len(tags) > 0 {
			name = tags[0]
		}

		result = append(result, ModelConfigEntry{
			Backend: k.backend,
			Model:   name,
			ModelID: k.modelID,
			Mode:    k.mode,
			Config:  cfg,
		})
	}
	return result
}
Loading