diff --git a/commands/backend.go b/commands/backend.go
new file mode 100644
index 00000000..35fe87a0
--- /dev/null
+++ b/commands/backend.go
@@ -0,0 +1,42 @@
+package commands
+
+import (
+	"errors"
+	"fmt"
+	"maps"
+	"os"
+	"slices"
+	"strings"
+)
+
+// ValidBackends enumerates the backends accepted by the --backend flag.
+var ValidBackends = map[string]bool{
+	"llama.cpp": true,
+	"openai":    true,
+}
+
+// validateBackend returns an error if the provided backend is not supported.
+func validateBackend(backend string) error {
+	if !ValidBackends[backend] {
+		return fmt.Errorf("invalid backend '%s'. Valid backends are: %s",
+			backend, ValidBackendsKeys())
+	}
+	return nil
+}
+
+// ensureAPIKey retrieves the API key required by the given backend, if any.
+func ensureAPIKey(backend string) (string, error) {
+	if backend == "openai" {
+		apiKey := os.Getenv("OPENAI_API_KEY")
+		if apiKey == "" {
+			return "", errors.New("OPENAI_API_KEY environment variable is required when using --backend=openai")
+		}
+		return apiKey, nil
+	}
+	return "", nil
+}
+
+// ValidBackendsKeys returns the valid backend names as a sorted, comma-separated list.
+func ValidBackendsKeys() string {
+	return strings.Join(slices.Sorted(maps.Keys(ValidBackends)), ", ")
+}
diff --git a/commands/list.go b/commands/list.go
index 91627d09..7ea1d44f 100644
--- a/commands/list.go
+++ b/commands/list.go
@@ -18,24 +18,39 @@ import (
 func newListCmd() *cobra.Command {
 	var jsonFormat, openai, quiet bool
+	var backend string
 	c := &cobra.Command{
 		Use:     "list [OPTIONS]",
 		Aliases: []string{"ls"},
 		Short:   "List the models pulled to your local environment",
 		RunE: func(cmd *cobra.Command, args []string) error {
-			if openai && quiet {
-				return fmt.Errorf("--quiet flag cannot be used with --openai flag")
+			// Validate backend if specified
+			if backend != "" {
+				if err := validateBackend(backend); err != nil {
+					return err
+				}
 			}
+
+			if (backend == "openai" || openai) && quiet {
+				return fmt.Errorf("--quiet flag cannot be used with --openai flag or OpenAI backend")
+			}
+
+			// Validate the API key for the OpenAI backend
+			apiKey, err := ensureAPIKey(backend)
+			if err != nil {
+				return err
+			}
+
 			// If we're doing an automatic install, only show the installation
 			// status if it won't corrupt machine-readable output.
 			var standaloneInstallPrinter standalone.StatusPrinter
-			if !jsonFormat && !openai && !quiet {
+			if !jsonFormat && !openai && !quiet && backend == "" {
 				standaloneInstallPrinter = cmd
 			}
 			if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), standaloneInstallPrinter); err != nil {
 				return fmt.Errorf("unable to initialize standalone model runner: %w", err)
 			}
-			models, err := listModels(openai, desktopClient, quiet, jsonFormat)
+			models, err := listModels(openai, backend, desktopClient, quiet, jsonFormat, apiKey)
 			if err != nil {
 				return err
 			}
@@ -47,12 +62,13 @@ func newListCmd() *cobra.Command {
 	c.Flags().BoolVar(&jsonFormat, "json", false, "List models in a JSON format")
 	c.Flags().BoolVar(&openai, "openai", false, "List models in an OpenAI format")
 	c.Flags().BoolVarP(&quiet, "quiet", "q", false, "Only show model IDs")
+	c.Flags().StringVar(&backend, "backend", "", fmt.Sprintf("Specify the backend to use (%s)", ValidBackendsKeys()))
 	return c
 }

-func listModels(openai bool, desktopClient *desktop.Client, quiet bool, jsonFormat bool) (string, error) {
-	if openai {
-		models, err := desktopClient.ListOpenAI()
+func listModels(openai bool, backend string, desktopClient *desktop.Client, quiet bool, jsonFormat bool, apiKey string) (string, error) {
+	if openai || backend == "openai" {
+		models, err := desktopClient.ListOpenAI(backend, apiKey)
 		if err != nil {
 			err = handleClientError(err, "Failed to list models")
 			return "", handleNotRunningError(err)
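The two helpers in `commands/backend.go` are the whole contract between the commands and the new flag. Here is a minimal test sketch — hypothetical, not part of this change — showing the behavior they are expected to have; it could live in `commands/backend_test.go` and relies only on the helpers above and `t.Setenv`:

```go
package commands

import "testing"

// Hypothetical sketch (not in this diff) exercising validateBackend and ensureAPIKey.
func TestBackendHelpers(t *testing.T) {
	// Unknown backends are rejected with the flag's error message.
	if err := validateBackend("vllm"); err == nil {
		t.Fatal("expected an error for an unsupported backend")
	}

	// The openai backend requires OPENAI_API_KEY; t.Setenv restores the
	// original value when the test finishes.
	t.Setenv("OPENAI_API_KEY", "")
	if _, err := ensureAPIKey("openai"); err == nil {
		t.Fatal("expected an error when OPENAI_API_KEY is unset")
	}

	t.Setenv("OPENAI_API_KEY", "sk-test")
	key, err := ensureAPIKey("openai")
	if err != nil || key != "sk-test" {
		t.Fatalf("got (%q, %v), want (\"sk-test\", nil)", key, err)
	}

	// Other backends never require a key.
	if key, err := ensureAPIKey("llama.cpp"); err != nil || key != "" {
		t.Fatalf("got (%q, %v), want (\"\", nil)", key, err)
	}
}
```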
diff --git a/commands/run.go b/commands/run.go
index cbe4f060..29336274 100644
--- a/commands/run.go
+++ b/commands/run.go
@@ -14,12 +14,26 @@ import (
 func newRunCmd() *cobra.Command {
 	var debug bool
+	var backend string
 	const cmdArgs = "MODEL [PROMPT]"
 	c := &cobra.Command{
 		Use:   "run " + cmdArgs,
 		Short: "Run a model and interact with it using a submitted prompt or chat mode",
 		RunE: func(cmd *cobra.Command, args []string) error {
+			// Validate backend if specified
+			if backend != "" {
+				if err := validateBackend(backend); err != nil {
+					return err
+				}
+			}
+
+			// Validate the API key for the OpenAI backend
+			apiKey, err := ensureAPIKey(backend)
+			if err != nil {
+				return err
+			}
+
 			model := args[0]
 			prompt := ""
 			if len(args) == 1 {
@@ -37,19 +51,22 @@ func newRunCmd() *cobra.Command {
 				return fmt.Errorf("unable to initialize standalone model runner: %w", err)
 			}

-			_, err := desktopClient.Inspect(model, false)
-			if err != nil {
-				if !errors.Is(err, desktop.ErrNotFound) {
-					return handleNotRunningError(handleClientError(err, "Failed to inspect model"))
-				}
-				cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
-				if err := pullModel(cmd, desktopClient, model); err != nil {
-					return err
+			// Skip local model validation for the OpenAI backend and let OpenAI handle it
+			if backend != "openai" {
+				_, err := desktopClient.Inspect(model, false)
+				if err != nil {
+					if !errors.Is(err, desktop.ErrNotFound) {
+						return handleNotRunningError(handleClientError(err, "Failed to inspect model"))
+					}
+					cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
+					if err := pullModel(cmd, desktopClient, model); err != nil {
+						return err
+					}
 				}
 			}

 			if prompt != "" {
-				if err := desktopClient.Chat(model, prompt); err != nil {
+				if err := desktopClient.Chat(backend, model, prompt, apiKey); err != nil {
 					return handleClientError(err, "Failed to generate a response")
 				}
 				cmd.Println()
@@ -73,7 +90,7 @@ func newRunCmd() *cobra.Command {
 					continue
 				}

-				if err := desktopClient.Chat(model, userInput); err != nil {
+				if err := desktopClient.Chat(backend, model, userInput, apiKey); err != nil {
 					cmd.PrintErr(handleClientError(err, "Failed to generate a response"))
 					cmd.Print("> ")
 					continue
@@ -104,6 +121,7 @@ func newRunCmd() *cobra.Command {
 	}

 	c.Flags().BoolVar(&debug, "debug", false, "Enable debug logging")
+	c.Flags().StringVar(&backend, "backend", "", fmt.Sprintf("Specify the backend to use (%s)", ValidBackendsKeys()))
 	return c
 }
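Both commands now share the same fail-fast shape: validate `--backend`, resolve the key, and only then touch the runner or the network. A condensed, self-contained sketch of that pattern (a hypothetical `demo` command, not the real wiring above; the validation is inlined to keep it standalone):

```go
package main

import (
	"fmt"
	"os"

	"github.com/spf13/cobra"
)

func main() {
	var backend string
	cmd := &cobra.Command{
		Use: "demo MODEL [PROMPT]",
		RunE: func(cmd *cobra.Command, args []string) error {
			// Reject an unknown backend before doing any work.
			if backend != "" && backend != "llama.cpp" && backend != "openai" {
				return fmt.Errorf("invalid backend '%s'. Valid backends are: llama.cpp, openai", backend)
			}
			// The OpenAI backend additionally requires a key in the environment.
			if backend == "openai" && os.Getenv("OPENAI_API_KEY") == "" {
				return fmt.Errorf("OPENAI_API_KEY environment variable is required when using --backend=openai")
			}
			cmd.Printf("would chat via backend %q\n", backend)
			return nil
		},
	}
	cmd.Flags().StringVar(&backend, "backend", "", "Specify the backend to use (llama.cpp, openai)")
	if err := cmd.Execute(); err != nil {
		os.Exit(1)
	}
}
```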
diff --git a/desktop/desktop.go b/desktop/desktop.go
index 6506d308..e4f5e7a1 100644
--- a/desktop/desktop.go
+++ b/desktop/desktop.go
@@ -19,6 +19,9 @@ import (
 	"go.opentelemetry.io/otel"
 )

+// DefaultBackend is the backend used when none is specified.
+const DefaultBackend = "llama.cpp"
+
 var (
 	ErrNotFound           = errors.New("model not found")
 	ErrServiceUnavailable = errors.New("service unavailable")
@@ -222,14 +224,30 @@ func (c *Client) List() ([]dmrm.Model, error) {
 	return modelsJson, nil
 }

-func (c *Client) ListOpenAI() (dmrm.OpenAIModelList, error) {
-	modelsRoute := inference.InferencePrefix + "/v1/models"
-	rawResponse, err := c.listRaw(modelsRoute, "")
+func (c *Client) ListOpenAI(backend, apiKey string) (dmrm.OpenAIModelList, error) {
+	if backend == "" {
+		backend = DefaultBackend
+	}
+	modelsRoute := fmt.Sprintf("%s/%s/v1/models", inference.InferencePrefix, backend)
+
+	// Use doRequestWithAuth to support API key authentication
+	resp, err := c.doRequestWithAuth(http.MethodGet, modelsRoute, nil, backend, apiKey)
+	if err != nil {
+		return dmrm.OpenAIModelList{}, c.handleQueryError(err, modelsRoute)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		return dmrm.OpenAIModelList{}, fmt.Errorf("failed to list models: %s", resp.Status)
+	}
+
+	body, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return dmrm.OpenAIModelList{}, err
+		return dmrm.OpenAIModelList{}, fmt.Errorf("failed to read response body: %w", err)
 	}
+
 	var modelsJson dmrm.OpenAIModelList
-	if err := json.Unmarshal(rawResponse, &modelsJson); err != nil {
+	if err := json.Unmarshal(body, &modelsJson); err != nil {
 		return modelsJson, fmt.Errorf("failed to unmarshal response body: %w", err)
 	}
 	return modelsJson, nil
@@ -329,7 +347,7 @@ func (c *Client) fullModelID(id string) (string, error) {
 	return "", fmt.Errorf("model with ID %s not found", id)
 }

-func (c *Client) Chat(model, prompt string) error {
+func (c *Client) Chat(backend, model, prompt, apiKey string) error {
 	model = normalizeHuggingFaceModelName(model)
 	if !strings.Contains(strings.Trim(model, "/"), "/") {
 		// Do an extra API call to check if the model parameter isn't a model ID.
@@ -354,14 +372,22 @@ func (c *Client) Chat(model, prompt string) error {
 		return fmt.Errorf("error marshaling request: %w", err)
 	}

-	chatCompletionsPath := inference.InferencePrefix + "/v1/chat/completions"
-	resp, err := c.doRequest(
+	var completionsPath string
+	if backend != "" {
+		completionsPath = inference.InferencePrefix + "/" + backend + "/v1/chat/completions"
+	} else {
+		completionsPath = inference.InferencePrefix + "/v1/chat/completions"
+	}
+
+	resp, err := c.doRequestWithAuth(
 		http.MethodPost,
-		chatCompletionsPath,
+		completionsPath,
 		bytes.NewReader(jsonData),
+		backend,
+		apiKey,
 	)
 	if err != nil {
-		return c.handleQueryError(err, chatCompletionsPath)
+		return c.handleQueryError(err, completionsPath)
 	}
 	defer resp.Body.Close()
@@ -576,6 +602,11 @@ func (c *Client) ConfigureBackend(request scheduling.ConfigureRequest) error {

 // doRequest is a helper function that performs HTTP requests and handles 503 responses
 func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response, error) {
+	return c.doRequestWithAuth(method, path, body, "", "")
+}
+
+// doRequestWithAuth is a helper function that performs HTTP requests with optional bearer-token authentication
+func (c *Client) doRequestWithAuth(method, path string, body io.Reader, backend, apiKey string) (*http.Response, error) {
 	req, err := http.NewRequest(method, c.modelRunner.URL(path), body)
 	if err != nil {
 		return nil, fmt.Errorf("error creating request: %w", err)
@@ -585,6 +616,12 @@ func (c *Client) doRequest(method, path string, body io.Reader) (*http.Response,
 	}
 	req.Header.Set("User-Agent", "docker-model-cli/"+Version)
+
+	// Add an Authorization header whenever an API key is provided (e.g. the OpenAI backend)
+	if apiKey != "" {
+		req.Header.Set("Authorization", "Bearer "+apiKey)
+	}
+
 	resp, err := c.modelRunner.Client().Do(req)
 	if err != nil {
 		return nil, err
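On the wire, the client changes reduce to two observable behaviors: a non-empty backend is spliced into the completions route, and a non-empty key becomes a bearer token. A runnable sketch of exactly that, using a placeholder prefix since the value of `inference.InferencePrefix` is not shown in this diff:

```go
package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
)

func main() {
	// Echo the path and Authorization header the "runner" would see.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Printf("%s %s Authorization=%q\n", r.Method, r.URL.Path, r.Header.Get("Authorization"))
	}))
	defer srv.Close()

	const prefix = "/inference" // placeholder for inference.InferencePrefix
	for _, tc := range []struct{ backend, apiKey string }{
		{"", ""},              // legacy route, no Authorization header
		{"openai", "sk-test"}, // backend-scoped route with bearer auth
	} {
		// Mirror the route selection in Chat above.
		path := prefix + "/v1/chat/completions"
		if tc.backend != "" {
			path = prefix + "/" + tc.backend + "/v1/chat/completions"
		}
		req, err := http.NewRequest(http.MethodPost, srv.URL+path, nil)
		if err != nil {
			panic(err)
		}
		// Mirror the header logic in doRequestWithAuth above.
		if tc.apiKey != "" {
			req.Header.Set("Authorization", "Bearer "+tc.apiKey)
		}
		resp, err := srv.Client().Do(req)
		if err != nil {
			panic(err)
		}
		resp.Body.Close()
	}
}
```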
diff --git a/desktop/desktop_test.go b/desktop/desktop_test.go
index 71ea9b9f..0654f7e9 100644
--- a/desktop/desktop_test.go
+++ b/desktop/desktop_test.go
@@ -63,7 +63,7 @@ func TestChatHuggingFaceModel(t *testing.T) {
 		Body: io.NopCloser(bytes.NewBufferString("data: {\"choices\":[{\"delta\":{\"content\":\"Hello there!\"}}]}\n")),
 	}, nil)

-	err := client.Chat(modelName, prompt)
+	err := client.Chat("", modelName, prompt, "")
 	assert.NoError(t, err)
 }
diff --git a/docs/reference/docker_model_list.yaml b/docs/reference/docker_model_list.yaml
index 292704ad..f157f09a 100644
--- a/docs/reference/docker_model_list.yaml
+++ b/docs/reference/docker_model_list.yaml
@@ -6,6 +6,15 @@ usage: docker model list [OPTIONS]
 pname: docker model
 plink: docker_model.yaml
 options:
+    - option: backend
+      value_type: string
+      description: Specify the backend to use (llama.cpp, openai)
+      deprecated: false
+      hidden: false
+      experimental: false
+      experimentalcli: false
+      kubernetes: false
+      swarm: false
     - option: json
       value_type: bool
      default_value: "false"
diff --git a/docs/reference/docker_model_run.yaml b/docs/reference/docker_model_run.yaml
index 4d18d3c6..a88f6c54 100644
--- a/docs/reference/docker_model_run.yaml
+++ b/docs/reference/docker_model_run.yaml
@@ -10,6 +10,15 @@ usage: docker model run MODEL [PROMPT]
 pname: docker model
 plink: docker_model.yaml
 options:
+    - option: backend
+      value_type: string
+      description: Specify the backend to use (llama.cpp, openai)
+      deprecated: false
+      hidden: false
+      experimental: false
+      experimentalcli: false
+      kubernetes: false
+      swarm: false
     - option: debug
       value_type: bool
       default_value: "false"
diff --git a/docs/reference/model_list.md b/docs/reference/model_list.md
index b6c051f2..947cd831 100644
--- a/docs/reference/model_list.md
+++ b/docs/reference/model_list.md
@@ -9,11 +9,12 @@ List the models pulled to your local environment

 ### Options

-| Name            | Type   | Default | Description                     |
-|:----------------|:-------|:--------|:--------------------------------|
-| `--json`        | `bool` |         | List models in a JSON format    |
-| `--openai`      | `bool` |         | List models in an OpenAI format |
-| `-q`, `--quiet` | `bool` |         | Only show model IDs             |
+| Name            | Type     | Default | Description                                    |
+|:----------------|:---------|:--------|:-----------------------------------------------|
+| `--backend`     | `string` |         | Specify the backend to use (llama.cpp, openai) |
+| `--json`        | `bool`   |         | List models in a JSON format                   |
+| `--openai`      | `bool`   |         | List models in an OpenAI format                |
+| `-q`, `--quiet` | `bool`   |         | Only show model IDs                            |
diff --git a/docs/reference/model_run.md b/docs/reference/model_run.md
index 3010f26c..55b427ca 100644
--- a/docs/reference/model_run.md
+++ b/docs/reference/model_run.md
@@ -5,9 +5,10 @@ Run a model and interact with it using a submitted prompt or chat mode

 ### Options

-| Name      | Type   | Default | Description          |
-|:----------|:-------|:--------|:---------------------|
-| `--debug` | `bool` |         | Enable debug logging |
+| Name        | Type     | Default | Description                                    |
+|:------------|:---------|:--------|:-----------------------------------------------|
+| `--backend` | `string` |         | Specify the backend to use (llama.cpp, openai) |
+| `--debug`   | `bool`   |         | Enable debug logging                           |