Skip to content

Commit d3a94ec

Browse files
authored
Merge pull request #2304 from masegraye/task/cagent
feat: add `docker agent models` command
2 parents 99fe1b0 + 54f0e1a commit d3a94ec

File tree

3 files changed

+396
-0
lines changed

3 files changed

+396
-0
lines changed

cmd/root/models.go

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
package root
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"slices"
8+
"strings"
9+
"text/tabwriter"
10+
11+
"github.com/spf13/cobra"
12+
13+
"github.com/docker/docker-agent/pkg/cli"
14+
"github.com/docker/docker-agent/pkg/config"
15+
"github.com/docker/docker-agent/pkg/config/latest"
16+
"github.com/docker/docker-agent/pkg/model/provider"
17+
"github.com/docker/docker-agent/pkg/modelsdev"
18+
"github.com/docker/docker-agent/pkg/telemetry"
19+
)
20+
21+
type modelsListFlags struct {
22+
providerFilter string
23+
format string
24+
all bool
25+
runConfig config.RuntimeConfig
26+
}
27+
28+
// modelRow represents a single model entry for display or serialization.
29+
type modelRow struct {
30+
Provider string `json:"provider"`
31+
Model string `json:"model"`
32+
Default bool `json:"default,omitempty"`
33+
}
34+
35+
func newModelsCmd() *cobra.Command {
36+
cmd := &cobra.Command{
37+
Use: "models",
38+
Short: "List available models",
39+
Long: `List models available for use with --model flag.
40+
41+
Shows models that can be passed to 'docker agent run --model' or
42+
'docker agent new --model'. By default shows models from providers
43+
you have credentials for. Use --all to include all providers.`,
44+
GroupID: "core",
45+
}
46+
47+
listCmd := newModelsListCmd()
48+
cmd.AddCommand(listCmd)
49+
50+
// Default to "list" when no subcommand given.
51+
cmd.RunE = listCmd.RunE
52+
53+
// Copy the flags from the list command so they work on the bare
54+
// "docker agent models --provider openai" form as well.
55+
cmd.Flags().AddFlagSet(listCmd.Flags())
56+
57+
return cmd
58+
}
59+
60+
func newModelsListCmd() *cobra.Command {
61+
var flags modelsListFlags
62+
63+
cmd := &cobra.Command{
64+
Use: "list",
65+
Aliases: []string{"ls"},
66+
Short: "List available models",
67+
Example: ` docker agent models
68+
docker agent models list --provider openai
69+
docker agent models ls --all
70+
docker agent models --format json`,
71+
Args: cobra.NoArgs,
72+
RunE: flags.runModelsListCommand,
73+
}
74+
75+
cmd.Flags().StringVarP(&flags.providerFilter, "provider", "p", "", "Filter by provider name")
76+
cmd.Flags().StringVar(&flags.format, "format", "table", "Output format: table, json")
77+
cmd.Flags().BoolVarP(&flags.all, "all", "a", false, "Include models from all providers, not just those with credentials")
78+
addGatewayFlags(cmd, &flags.runConfig)
79+
80+
return cmd
81+
}
82+
83+
func (f *modelsListFlags) runModelsListCommand(cmd *cobra.Command, args []string) (commandErr error) {
84+
ctx := cmd.Context()
85+
telemetry.TrackCommand(ctx, "models", append([]string{"list"}, args...))
86+
defer func() {
87+
telemetry.TrackCommandError(ctx, "models", append([]string{"list"}, args...), commandErr)
88+
}()
89+
90+
out := cli.NewPrinter(cmd.OutOrStdout())
91+
env := f.runConfig.EnvProvider()
92+
93+
// Determine which providers the user has credentials for.
94+
availableProviders := make(map[string]bool)
95+
for _, p := range config.AvailableProviders(ctx, f.runConfig.ModelsGateway, env) {
96+
availableProviders[p] = true
97+
}
98+
99+
// Determine which model auto-selection would pick.
100+
autoModel := config.AutoModelConfig(ctx, f.runConfig.ModelsGateway, env, f.runConfig.DefaultModel)
101+
102+
rows := f.collectModels(ctx, availableProviders, autoModel)
103+
104+
// Apply provider filter
105+
if f.providerFilter != "" {
106+
rows = slices.DeleteFunc(rows, func(r modelRow) bool {
107+
return !strings.EqualFold(r.Provider, f.providerFilter)
108+
})
109+
}
110+
111+
// Sort: default first, then by provider, then by model
112+
slices.SortFunc(rows, func(a, b modelRow) int {
113+
if a.Default != b.Default {
114+
if a.Default {
115+
return -1
116+
}
117+
return 1
118+
}
119+
if c := strings.Compare(a.Provider, b.Provider); c != 0 {
120+
return c
121+
}
122+
return strings.Compare(a.Model, b.Model)
123+
})
124+
125+
if len(rows) == 0 {
126+
out.Println("No models available.")
127+
out.Println("\nConfigure a provider API key or install Docker Model Runner.")
128+
return nil
129+
}
130+
131+
switch f.format {
132+
case "json":
133+
return f.renderJSON(cmd, rows)
134+
default:
135+
f.renderTable(cmd, rows)
136+
}
137+
138+
return nil
139+
}
140+
141+
// collectModels returns all models from the catalog, filtered by credential
142+
// availability unless --all is set. Default models for each available provider
143+
// are always included even if the catalog fetch fails.
144+
func (f *modelsListFlags) collectModels(ctx context.Context, availableProviders map[string]bool, autoModel latest.ModelConfig) []modelRow {
145+
seen := make(map[string]bool)
146+
var rows []modelRow
147+
148+
// Always include the per-provider defaults so we have something even
149+
// if the catalog is unreachable.
150+
for prov, model := range config.DefaultModels {
151+
if !f.all && !availableProviders[prov] {
152+
continue
153+
}
154+
ref := prov + "/" + model
155+
seen[ref] = true
156+
rows = append(rows, modelRow{
157+
Provider: prov,
158+
Model: model,
159+
Default: prov == autoModel.Provider && model == autoModel.Model,
160+
})
161+
}
162+
163+
// Fetch catalog and add all text-capable models.
164+
store, err := modelsdev.NewStore()
165+
if err != nil {
166+
return rows
167+
}
168+
db, err := store.GetDatabase(ctx)
169+
if err != nil {
170+
return rows
171+
}
172+
173+
for providerID, prov := range db.Providers {
174+
if !provider.IsCatalogProvider(providerID) {
175+
continue
176+
}
177+
if !f.all && !availableProviders[providerID] {
178+
continue
179+
}
180+
for modelID, model := range prov.Models {
181+
if !slices.Contains(model.Modalities.Output, "text") {
182+
continue
183+
}
184+
if isEmbeddingModel(model.Family, model.Name) {
185+
continue
186+
}
187+
188+
ref := providerID + "/" + modelID
189+
if seen[ref] {
190+
continue
191+
}
192+
seen[ref] = true
193+
194+
rows = append(rows, modelRow{
195+
Provider: providerID,
196+
Model: modelID,
197+
})
198+
}
199+
}
200+
201+
return rows
202+
}
203+
204+
func isEmbeddingModel(family, name string) bool {
205+
fl := strings.ToLower(family)
206+
nl := strings.ToLower(name)
207+
return strings.Contains(fl, "embed") || strings.Contains(nl, "embed")
208+
}
209+
210+
func (f *modelsListFlags) renderTable(cmd *cobra.Command, rows []modelRow) {
211+
w := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 2, 3, ' ', 0)
212+
fmt.Fprintln(w, "PROVIDER\tMODEL\tDEFAULT")
213+
for _, r := range rows {
214+
def := ""
215+
if r.Default {
216+
def = "*"
217+
}
218+
fmt.Fprintf(w, "%s\t%s\t%s\n", r.Provider, r.Model, def)
219+
}
220+
w.Flush()
221+
}
222+
223+
func (f *modelsListFlags) renderJSON(cmd *cobra.Command, rows []modelRow) error {
224+
enc := json.NewEncoder(cmd.OutOrStdout())
225+
enc.SetIndent("", " ")
226+
return enc.Encode(rows)
227+
}

cmd/root/models_test.go

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
package root
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"strings"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
12+
"github.com/docker/docker-agent/pkg/config"
13+
"github.com/docker/docker-agent/pkg/userconfig"
14+
)
15+
16+
func TestModelsListCommand_DefaultOutput(t *testing.T) {
17+
// With ANTHROPIC_API_KEY set, the default output should include
18+
// at least the anthropic default model.
19+
t.Setenv("ANTHROPIC_API_KEY", "test-key")
20+
t.Setenv("DOCKER_AGENT_MODELS_GATEWAY", "")
21+
t.Setenv("DOCKER_AGENT_DEFAULT_MODEL", "")
22+
23+
original := loadUserConfig
24+
loadUserConfig = func() (*userconfig.Config, error) { return &userconfig.Config{}, nil }
25+
t.Cleanup(func() { loadUserConfig = original })
26+
27+
var buf bytes.Buffer
28+
cmd := newModelsCmd()
29+
cmd.SetOut(&buf)
30+
cmd.SetErr(&buf)
31+
cmd.SetArgs(nil)
32+
33+
err := cmd.Execute()
34+
require.NoError(t, err)
35+
36+
output := buf.String()
37+
assert.Contains(t, output, "PROVIDER")
38+
assert.Contains(t, output, "MODEL")
39+
assert.Contains(t, output, "anthropic")
40+
}
41+
42+
func TestModelsListCommand_ProviderFilter(t *testing.T) {
43+
t.Setenv("ANTHROPIC_API_KEY", "test-key")
44+
t.Setenv("OPENAI_API_KEY", "test-key")
45+
t.Setenv("DOCKER_AGENT_MODELS_GATEWAY", "")
46+
t.Setenv("DOCKER_AGENT_DEFAULT_MODEL", "")
47+
48+
original := loadUserConfig
49+
loadUserConfig = func() (*userconfig.Config, error) { return &userconfig.Config{}, nil }
50+
t.Cleanup(func() { loadUserConfig = original })
51+
52+
var buf bytes.Buffer
53+
cmd := newModelsCmd()
54+
cmd.SetOut(&buf)
55+
cmd.SetErr(&buf)
56+
cmd.SetArgs([]string{"--provider", "anthropic"})
57+
58+
err := cmd.Execute()
59+
require.NoError(t, err)
60+
61+
output := buf.String()
62+
// Every non-header line should be anthropic
63+
for line := range strings.SplitSeq(output, "\n") {
64+
line = strings.TrimSpace(line)
65+
if line == "" || strings.HasPrefix(line, "PROVIDER") {
66+
continue
67+
}
68+
assert.True(t, strings.HasPrefix(line, "anthropic"),
69+
"expected anthropic provider, got: %s", line)
70+
}
71+
}
72+
73+
func TestModelsListCommand_JSONFormat(t *testing.T) {
74+
t.Setenv("ANTHROPIC_API_KEY", "test-key")
75+
t.Setenv("DOCKER_AGENT_MODELS_GATEWAY", "")
76+
t.Setenv("DOCKER_AGENT_DEFAULT_MODEL", "")
77+
78+
original := loadUserConfig
79+
loadUserConfig = func() (*userconfig.Config, error) { return &userconfig.Config{}, nil }
80+
t.Cleanup(func() { loadUserConfig = original })
81+
82+
var buf bytes.Buffer
83+
cmd := newModelsCmd()
84+
cmd.SetOut(&buf)
85+
cmd.SetErr(&buf)
86+
cmd.SetArgs([]string{"--format", "json"})
87+
88+
err := cmd.Execute()
89+
require.NoError(t, err)
90+
91+
var rows []modelRow
92+
err = json.Unmarshal(buf.Bytes(), &rows)
93+
require.NoError(t, err)
94+
assert.NotEmpty(t, rows)
95+
96+
// At least one should be the default
97+
hasDefault := false
98+
for _, r := range rows {
99+
if r.Default {
100+
hasDefault = true
101+
break
102+
}
103+
}
104+
assert.True(t, hasDefault, "expected at least one default model")
105+
}
106+
107+
func TestModelsListCommand_DefaultMarker(t *testing.T) {
108+
// When a default model is configured via env, it should be marked.
109+
t.Setenv("ANTHROPIC_API_KEY", "test-key")
110+
t.Setenv("DOCKER_AGENT_MODELS_GATEWAY", "")
111+
t.Setenv("DOCKER_AGENT_DEFAULT_MODEL", "")
112+
113+
original := loadUserConfig
114+
loadUserConfig = func() (*userconfig.Config, error) { return &userconfig.Config{}, nil }
115+
t.Cleanup(func() { loadUserConfig = original })
116+
117+
var buf bytes.Buffer
118+
cmd := newModelsCmd()
119+
cmd.SetOut(&buf)
120+
cmd.SetErr(&buf)
121+
cmd.SetArgs([]string{"--format", "json"})
122+
123+
err := cmd.Execute()
124+
require.NoError(t, err)
125+
126+
var rows []modelRow
127+
require.NoError(t, json.Unmarshal(buf.Bytes(), &rows))
128+
129+
// The auto-selected model should be marked as default
130+
rc := config.RuntimeConfig{}
131+
autoModel := config.AutoModelConfig(t.Context(), "", rc.EnvProvider(), nil)
132+
for _, r := range rows {
133+
if r.Provider == autoModel.Provider && r.Model == autoModel.Model {
134+
assert.True(t, r.Default, "auto-selected model %s/%s should be marked as default", r.Provider, r.Model)
135+
}
136+
}
137+
}
138+
139+
func TestModelsListCommand_NoCredentials(t *testing.T) {
140+
// Clear all provider keys — only DMR should remain as fallback.
141+
t.Setenv("ANTHROPIC_API_KEY", "")
142+
t.Setenv("OPENAI_API_KEY", "")
143+
t.Setenv("GOOGLE_API_KEY", "")
144+
t.Setenv("GEMINI_API_KEY", "")
145+
t.Setenv("MISTRAL_API_KEY", "")
146+
t.Setenv("AWS_ACCESS_KEY_ID", "")
147+
t.Setenv("AWS_PROFILE", "")
148+
t.Setenv("AWS_ROLE_ARN", "")
149+
t.Setenv("DOCKER_AGENT_MODELS_GATEWAY", "")
150+
t.Setenv("DOCKER_AGENT_DEFAULT_MODEL", "")
151+
152+
original := loadUserConfig
153+
loadUserConfig = func() (*userconfig.Config, error) { return &userconfig.Config{}, nil }
154+
t.Cleanup(func() { loadUserConfig = original })
155+
156+
var buf bytes.Buffer
157+
cmd := newModelsCmd()
158+
cmd.SetOut(&buf)
159+
cmd.SetErr(&buf)
160+
cmd.SetArgs(nil)
161+
162+
err := cmd.Execute()
163+
require.NoError(t, err)
164+
165+
output := buf.String()
166+
// DMR is always available as fallback
167+
assert.Contains(t, output, "dmr")
168+
}

0 commit comments

Comments
 (0)