This repository was archived by the owner on Oct 6, 2025. It is now read-only.
Merged
41 changes: 41 additions & 0 deletions commands/backend.go
@@ -0,0 +1,41 @@
+package commands
+
+import (
+    "errors"
+    "fmt"
+    "maps"
+    "os"
+    "slices"
+    "strings"
+)
+
+// ValidBackends is the set of supported inference backends.
+var ValidBackends = map[string]bool{
+    "llama.cpp": true,
+    "openai":    true,
+}
+
+// validateBackend checks whether the provided backend is valid.
+func validateBackend(backend string) error {
+    if !ValidBackends[backend] {
+        return fmt.Errorf("invalid backend '%s'. Valid backends are: %s",
+            backend, ValidBackendsKeys())
+    }
+    return nil
+}
+
+// ensureAPIKey retrieves the API key if the backend requires one.
+func ensureAPIKey(backend string) (string, error) {
+    if backend == "openai" {
+        apiKey := os.Getenv("OPENAI_API_KEY")
+        if apiKey == "" {
+            return "", errors.New("OPENAI_API_KEY environment variable is required when using --backend=openai")
+        }
+        return apiKey, nil
+    }
+    return "", nil
+}
+
+// ValidBackendsKeys returns the valid backend names as a comma-separated list.
+func ValidBackendsKeys() string {
+    return strings.Join(slices.Collect(maps.Keys(ValidBackends)), ", ")
+}
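For orientation, a minimal sketch of how these helpers are meant to compose. This is illustrative only, not part of the PR: the function name `demoBackendHelpers` and the `"vllm"` value are made up, and since `validateBackend` and `ensureAPIKey` are unexported, such code would have to live inside the `commands` package.

package commands

import "fmt"

// demoBackendHelpers walks through the expected behavior of the helpers above.
func demoBackendHelpers() {
    // Unknown backends are rejected with the list of valid names.
    if err := validateBackend("vllm"); err != nil {
        fmt.Println(err) // invalid backend 'vllm'. Valid backends are: ...
    }

    // llama.cpp needs no credentials: ensureAPIKey returns ("", nil).
    if key, err := ensureAPIKey("llama.cpp"); err == nil {
        fmt.Printf("llama.cpp key: %q\n", key)
    }

    // openai fails fast unless OPENAI_API_KEY is set in the environment.
    if _, err := ensureAPIKey("openai"); err != nil {
        fmt.Println(err)
    }
}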
30 changes: 23 additions & 7 deletions commands/list.go
@@ -18,24 +18,39 @@ import (

func newListCmd() *cobra.Command {
    var jsonFormat, openai, quiet bool
+   var backend string
    c := &cobra.Command{
        Use:     "list [OPTIONS]",
        Aliases: []string{"ls"},
        Short:   "List the models pulled to your local environment",
        RunE: func(cmd *cobra.Command, args []string) error {
-           if openai && quiet {
-               return fmt.Errorf("--quiet flag cannot be used with --openai flag")
+           // Validate backend if specified
+           if backend != "" {
+               if err := validateBackend(backend); err != nil {
+                   return err
+               }
+           }
+
+           if (backend == "openai" || openai) && quiet {
+               return fmt.Errorf("--quiet flag cannot be used with --openai flag or OpenAI backend")
            }
+
+           // Validate API key for OpenAI backend
+           apiKey, err := ensureAPIKey(backend)
+           if err != nil {
+               return err
+           }

            // If we're doing an automatic install, only show the installation
            // status if it won't corrupt machine-readable output.
            var standaloneInstallPrinter standalone.StatusPrinter
-           if !jsonFormat && !openai && !quiet {
+           if !jsonFormat && !openai && !quiet && backend == "" {
                standaloneInstallPrinter = cmd
            }
            if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), standaloneInstallPrinter); err != nil {
                return fmt.Errorf("unable to initialize standalone model runner: %w", err)
            }
-           models, err := listModels(openai, desktopClient, quiet, jsonFormat)
+           models, err := listModels(openai, backend, desktopClient, quiet, jsonFormat, apiKey)
            if err != nil {
                return err
            }
@@ -47,12 +62,13 @@ func newListCmd() *cobra.Command {
    c.Flags().BoolVar(&jsonFormat, "json", false, "List models in a JSON format")
    c.Flags().BoolVar(&openai, "openai", false, "List models in an OpenAI format")
    c.Flags().BoolVarP(&quiet, "quiet", "q", false, "Only show model IDs")
+   c.Flags().StringVar(&backend, "backend", "", fmt.Sprintf("Specify the backend to use (%s)", ValidBackendsKeys()))
    return c
}

-func listModels(openai bool, desktopClient *desktop.Client, quiet bool, jsonFormat bool) (string, error) {
-   if openai {
-       models, err := desktopClient.ListOpenAI()
+func listModels(openai bool, backend string, desktopClient *desktop.Client, quiet bool, jsonFormat bool, apiKey string) (string, error) {
+   if openai || backend == "openai" {
+       models, err := desktopClient.ListOpenAI(backend, apiKey)
        if err != nil {
            err = handleClientError(err, "Failed to list models")
            return "", handleNotRunningError(err)
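The RunE above follows a validate-then-authenticate ordering: check the flag value, resolve credentials, and only then touch the runner, so that `OPENAI_API_KEY=... docker model list --backend=openai` fails fast on bad input. A self-contained sketch of that pattern, assuming only the standard cobra API (all names here are illustrative, not the repo's):

package main

import (
    "fmt"
    "os"

    "github.com/spf13/cobra"
)

func main() {
    var backend string
    cmd := &cobra.Command{
        Use: "list",
        RunE: func(cmd *cobra.Command, args []string) error {
            // Validate the flag before doing any work.
            if backend != "" && backend != "llama.cpp" && backend != "openai" {
                return fmt.Errorf("invalid backend '%s'", backend)
            }
            // Resolve credentials next, failing fast when they are missing.
            if backend == "openai" && os.Getenv("OPENAI_API_KEY") == "" {
                return fmt.Errorf("OPENAI_API_KEY environment variable is required when using --backend=openai")
            }
            fmt.Printf("would list models via backend %q\n", backend)
            return nil
        },
    }
    cmd.Flags().StringVar(&backend, "backend", "", "Specify the backend to use (llama.cpp, openai)")
    if err := cmd.Execute(); err != nil {
        os.Exit(1)
    }
}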
38 changes: 28 additions & 10 deletions commands/run.go
@@ -14,12 +14,26 @@ import (

func newRunCmd() *cobra.Command {
    var debug bool
+   var backend string

    const cmdArgs = "MODEL [PROMPT]"
    c := &cobra.Command{
        Use:   "run " + cmdArgs,
        Short: "Run a model and interact with it using a submitted prompt or chat mode",
        RunE: func(cmd *cobra.Command, args []string) error {
+           // Validate backend if specified
+           if backend != "" {
+               if err := validateBackend(backend); err != nil {
+                   return err
+               }
+           }
+
+           // Validate API key for OpenAI backend
+           apiKey, err := ensureAPIKey(backend)
+           if err != nil {
+               return err
+           }
+
            model := args[0]
            prompt := ""
            if len(args) == 1 {
@@ -37,19 +51,22 @@ func newRunCmd() *cobra.Command {
return fmt.Errorf("unable to initialize standalone model runner: %w", err)
}

_, err := desktopClient.Inspect(model, false)
if err != nil {
if !errors.Is(err, desktop.ErrNotFound) {
return handleNotRunningError(handleClientError(err, "Failed to inspect model"))
}
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
if err := pullModel(cmd, desktopClient, model); err != nil {
return err
// Do not validate the model in case of using OpenAI's backend, let OpenAI handle it
if backend != "openai" {
_, err := desktopClient.Inspect(model, false)
if err != nil {
if !errors.Is(err, desktop.ErrNotFound) {
return handleNotRunningError(handleClientError(err, "Failed to inspect model"))
}
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
if err := pullModel(cmd, desktopClient, model); err != nil {
return err
}
}
}

if prompt != "" {
if err := desktopClient.Chat(model, prompt); err != nil {
if err := desktopClient.Chat(backend, model, prompt, apiKey); err != nil {
return handleClientError(err, "Failed to generate a response")
}
cmd.Println()
Expand All @@ -73,7 +90,7 @@ func newRunCmd() *cobra.Command {
continue
}

if err := desktopClient.Chat(model, userInput); err != nil {
if err := desktopClient.Chat(backend, model, userInput, apiKey); err != nil {
cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
cmd.Print("> ")
continue
@@ -104,6 +121,7 @@ func newRunCmd() *cobra.Command {
    }

    c.Flags().BoolVar(&debug, "debug", false, "Enable debug logging")
+   c.Flags().StringVar(&backend, "backend", "", fmt.Sprintf("Specify the backend to use (%s)", ValidBackendsKeys()))

    return c
}
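The key behavioral change in this file is the local-model bypass: only non-OpenAI backends consult the local store before chatting. A condensed sketch of that decision (the function name is hypothetical, not part of the PR):

// shouldEnsureLocalModel reports whether `run` should inspect and, if
// necessary, pull the model locally before starting a chat. For the
// openai backend the model lives remotely, so validation is deferred
// to the OpenAI API itself.
func shouldEnsureLocalModel(backend string) bool {
    return backend != "openai"
}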
57 changes: 47 additions & 10 deletions desktop/desktop.go
@@ -19,6 +19,8 @@ import (
"go.opentelemetry.io/otel"
)

const DefaultBackend = "llama.cpp"

var (
ErrNotFound = errors.New("model not found")
ErrServiceUnavailable = errors.New("service unavailable")
@@ -222,14 +224,30 @@ func (c *Client) List() ([]dmrm.Model, error) {
    return modelsJson, nil
}

-func (c *Client) ListOpenAI() (dmrm.OpenAIModelList, error) {
-   modelsRoute := inference.InferencePrefix + "/v1/models"
-   rawResponse, err := c.listRaw(modelsRoute, "")
-   if err != nil {
-       return dmrm.OpenAIModelList{}, err
-   }
+func (c *Client) ListOpenAI(backend, apiKey string) (dmrm.OpenAIModelList, error) {
+   if backend == "" {
+       backend = DefaultBackend
+   }
+   modelsRoute := fmt.Sprintf("%s/%s/v1/models", inference.InferencePrefix, backend)
+
+   // Use doRequestWithAuth to support API key authentication
+   resp, err := c.doRequestWithAuth(http.MethodGet, modelsRoute, nil, "openai", apiKey)
+   if err != nil {
+       return dmrm.OpenAIModelList{}, c.handleQueryError(err, modelsRoute)
+   }
+   defer resp.Body.Close()
+
+   if resp.StatusCode != http.StatusOK {
+       return dmrm.OpenAIModelList{}, fmt.Errorf("failed to list models: %s", resp.Status)
+   }
+
+   body, err := io.ReadAll(resp.Body)
+   if err != nil {
+       return dmrm.OpenAIModelList{}, fmt.Errorf("failed to read response body: %w", err)
+   }
+
    var modelsJson dmrm.OpenAIModelList
-   if err := json.Unmarshal(rawResponse, &modelsJson); err != nil {
+   if err := json.Unmarshal(body, &modelsJson); err != nil {
        return modelsJson, fmt.Errorf("failed to unmarshal response body: %w", err)
    }
    return modelsJson, nil
@@ -329,7 +347,7 @@ func (c *Client) fullModelID(id string) (string, error) {
return "", fmt.Errorf("model with ID %s not found", id)
}

func (c *Client) Chat(model, prompt string) error {
func (c *Client) Chat(backend, model, prompt, apiKey string) error {
model = normalizeHuggingFaceModelName(model)
if !strings.Contains(strings.Trim(model, "/"), "/") {
// Do an extra API call to check if the model parameter isn't a model ID.
@@ -354,14 +372,22 @@ func (c *Client) Chat(model, prompt string) error {
return fmt.Errorf("error marshaling request: %w", err)
}

chatCompletionsPath := inference.InferencePrefix + "/v1/chat/completions"
resp, err := c.doRequest(
var completionsPath string
if backend != "" {
completionsPath = inference.InferencePrefix + "/" + backend + "/v1/chat/completions"
} else {
completionsPath = inference.InferencePrefix + "/v1/chat/completions"
}

resp, err := c.doRequestWithAuth(
http.MethodPost,
chatCompletionsPath,
completionsPath,
bytes.NewReader(jsonData),
backend,
apiKey,
)
if err != nil {
return c.handleQueryError(err, chatCompletionsPath)
return c.handleQueryError(err, completionsPath)
}
defer resp.Body.Close()

@@ -576,6 +602,11 @@ func (c *Client) ConfigureBackend(request scheduling.ConfigureRequest) error {

// doRequest is a helper function that performs HTTP requests and handles 503 responses
func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) {
+   return c.doRequestWithAuth(method, path, body, "", "")
+}
+
+// doRequestWithAuth is a helper function that performs HTTP requests with optional authentication
+func (c *Client) doRequestWithAuth(method, path string, body io.Reader, backend, apiKey string) (*http.Response, error) {
    req, err := http.NewRequest(method, c.modelRunner.URL(path), body)
    if err != nil {
        return nil, fmt.Errorf("error creating request: %w", err)
@@ -585,6 +616,12 @@ func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response,
    }

    req.Header.Set("User-Agent", "docker-model-cli/"+Version)
+
+   // Add Authorization header for OpenAI backend
+   if apiKey != "" {
+       req.Header.Set("Authorization", "Bearer "+apiKey)
+   }
+
    resp, err := c.modelRunner.Client().Do(req)
    if err != nil {
        return nil, err
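Routing is the subtle part of this change: a non-empty backend is spliced into the inference path as an extra segment. A runnable sketch of the resulting URLs; the `/engines` prefix here is a placeholder, since the real value comes from `inference.InferencePrefix`:

package main

import "fmt"

const inferencePrefix = "/engines" // placeholder for inference.InferencePrefix

// completionsPath mirrors the branch added to Chat above.
func completionsPath(backend string) string {
    if backend != "" {
        return inferencePrefix + "/" + backend + "/v1/chat/completions"
    }
    return inferencePrefix + "/v1/chat/completions"
}

func main() {
    fmt.Println(completionsPath(""))       // /engines/v1/chat/completions (default runner route)
    fmt.Println(completionsPath("openai")) // /engines/openai/v1/chat/completions (proxied to OpenAI)
}

Note the asymmetry in the diff itself: ListOpenAI substitutes DefaultBackend for an empty backend, while Chat falls back to the unprefixed route.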
2 changes: 1 addition & 1 deletion desktop/desktop_test.go
@@ -63,7 +63,7 @@ func TestChatHuggingFaceModel(t *testing.T) {
        Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")),
    }, nil)

-   err := client.Chat(modelName, prompt)
+   err := client.Chat("", modelName, prompt, "")
    assert.NoError(t, err)
}

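A natural companion test for the new auth path, sketched here with only the standard library rather than the repo's mock-client helpers (the test name and URL are illustrative, not part of the PR):

package desktop

import (
    "net/http"
    "testing"
)

// TestBearerHeaderIsSet mirrors the guard in doRequestWithAuth: the
// Authorization header is attached only when an API key is provided.
func TestBearerHeaderIsSet(t *testing.T) {
    req, err := http.NewRequest(http.MethodGet, "http://runner.invalid/engines/openai/v1/models", nil)
    if err != nil {
        t.Fatal(err)
    }
    apiKey := "test-key"
    if apiKey != "" {
        req.Header.Set("Authorization", "Bearer "+apiKey)
    }
    if got, want := req.Header.Get("Authorization"), "Bearer test-key"; got != want {
        t.Errorf("Authorization = %q, want %q", got, want)
    }
}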
9 changes: 9 additions & 0 deletions docs/reference/docker_model_list.yaml
@@ -6,6 +6,15 @@ usage: docker model list [OPTIONS]
pname: docker model
plink: docker_model.yaml
options:
+  - option: backend
+    value_type: string
+    description: Specify the backend to use (llama.cpp, openai)
+    deprecated: false
+    hidden: false
+    experimental: false
+    experimentalcli: false
+    kubernetes: false
+    swarm: false
  - option: json
    value_type: bool
    default_value: "false"
9 changes: 9 additions & 0 deletions docs/reference/docker_model_run.yaml
@@ -10,6 +10,15 @@ usage: docker model run MODEL [PROMPT]
pname: docker model
plink: docker_model.yaml
options:
+  - option: backend
+    value_type: string
+    description: Specify the backend to use (llama.cpp, openai)
+    deprecated: false
+    hidden: false
+    experimental: false
+    experimentalcli: false
+    kubernetes: false
+    swarm: false
  - option: debug
    value_type: bool
    default_value: "false"
11 changes: 6 additions & 5 deletions docs/reference/model_list.md
@@ -9,11 +9,12 @@ List the models pulled to your local environment

### Options

-| Name            | Type   | Default | Description                     |
-|:----------------|:-------|:--------|:--------------------------------|
-| `--json`        | `bool` |         | List models in a JSON format    |
-| `--openai`      | `bool` |         | List models in an OpenAI format |
-| `-q`, `--quiet` | `bool` |         | Only show model IDs             |
+| Name            | Type     | Default | Description                                    |
+|:----------------|:---------|:--------|:-----------------------------------------------|
+| `--backend`     | `string` |         | Specify the backend to use (llama.cpp, openai) |
+| `--json`        | `bool`   |         | List models in a JSON format                   |
+| `--openai`      | `bool`   |         | List models in an OpenAI format                |
+| `-q`, `--quiet` | `bool`   |         | Only show model IDs                            |


<!---MARKER_GEN_END-->
7 changes: 4 additions & 3 deletions docs/reference/model_run.md
@@ -5,9 +5,10 @@ Run a model and interact with it using a submitted prompt or chat mode

### Options

-| Name      | Type   | Default | Description          |
-|:----------|:-------|:--------|:---------------------|
-| `--debug` | `bool` |         | Enable debug logging |
+| Name        | Type     | Default | Description                                    |
+|:------------|:---------|:--------|:-----------------------------------------------|
+| `--backend` | `string` |         | Specify the backend to use (llama.cpp, openai) |
+| `--debug`   | `bool`   |         | Enable debug logging                           |


<!---MARKER_GEN_END-->
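For reference, the invocations these docs imply; the model name and key are placeholders, not values taken from the PR:

    docker model list --backend=openai
    OPENAI_API_KEY=sk-... docker model run --backend=openai gpt-4o-mini "Hello"

Without `--backend`, both commands keep their existing behavior against the local runner.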