Skip to content

Commit e660d0b

Browse files
author
Marek Safarik
committed
generic model listing and runtime model switching
Signed-off-by: Marek Safarik <msafarik@redhat.com>
1 parent 3cd9b15 commit e660d0b

File tree

8 files changed

+402
-37
lines changed

8 files changed

+402
-37
lines changed

cmd/client/main.go

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ func main() {
5353
return
5454
}
5555

56-
provider, model := createProvider()
56+
providers, initialProvider, initialModel := createProviders()
5757

5858
cfg := agent.Config{ServerPath: serverPath}
59-
if model != "" {
60-
cfg.Model = model
59+
if initialModel != "" {
60+
cfg.Model = initialModel
6161
}
6262

63-
agentInstance := agent.NewAgent(cfg, provider)
63+
agentInstance := agent.NewAgent(cfg, initialProvider)
6464

6565
if err := agentInstance.Connect(ctx); err != nil {
6666
log.Printf("Failed to connect to MCP server: %v", err)
@@ -74,7 +74,7 @@ func main() {
7474
return
7575
}
7676

77-
srv, err := web.NewServer(ctx, agentInstance)
77+
srv, err := web.NewServer(ctx, agentInstance, providers)
7878
if err != nil {
7979
log.Printf("Failed to create web server: %v", err)
8080
return
@@ -89,26 +89,35 @@ func main() {
8989
}
9090
}
9191

92-
// createProvider selects the LLM provider based on environment variables.
93-
// OLLAMA_URL or OLLAMA_MODEL → local Ollama (Anthropic-compatible API)
94-
// ANTHROPIC_API_KEY → Claude cloud API
95-
func createProvider() (agent.LLMProvider, string) {
92+
func createProviders() ([]agent.LLMProvider, agent.LLMProvider, string) {
9693
ollamaURL := os.Getenv("OLLAMA_URL")
9794
ollamaModel := os.Getenv("OLLAMA_MODEL")
95+
apiKey := strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY"))
96+
97+
if ollamaURL == "" {
98+
ollamaURL = defaultOllamaURL
99+
}
100+
101+
var providers []agent.LLMProvider
98102

99-
if ollamaURL != "" || ollamaModel != "" {
100-
if ollamaURL == "" {
101-
ollamaURL = defaultOllamaURL
102-
}
103+
var claudeProvider *agent.AnthropicProvider
104+
if apiKey != "" {
105+
claudeProvider = agent.NewClaudeProvider(apiKey)
106+
providers = append(providers, claudeProvider)
107+
}
108+
109+
ollamaProvider := agent.NewOllamaProvider(ollamaURL)
110+
providers = append(providers, ollamaProvider)
111+
112+
if ollamaModel != "" || os.Getenv("OLLAMA_URL") != "" {
103113
log.Printf("Using Ollama provider at %s", ollamaURL)
104-
return agent.NewOllamaProvider(ollamaURL), ollamaModel
114+
return providers, ollamaProvider, ollamaModel
105115
}
106116

107-
apiKey := strings.TrimSpace(os.Getenv("ANTHROPIC_API_KEY"))
108-
if apiKey == "" {
117+
if claudeProvider == nil {
109118
log.Fatal("Set ANTHROPIC_API_KEY for Claude or OLLAMA_URL/OLLAMA_MODEL for local Ollama")
110119
}
111120

112121
log.Printf("Using Claude provider")
113-
return agent.NewClaudeProvider(apiKey), ""
122+
return providers, claudeProvider, ""
114123
}

internal/agent/agent.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,12 @@ func (a *Agent) Reset() {
228228
a.messages = []Message{}
229229
a.toolQueue = nil
230230
}
231+
232+
func (a *Agent) SetModel(provider LLMProvider, model string) {
233+
a.provider = provider
234+
a.config.Model = model
235+
}
236+
237+
func (a *Agent) GetModel() string {
238+
return a.config.Model
239+
}

internal/agent/provider.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@ import (
88

99
// LLMProvider defines the interface for LLM backends (Adapter pattern).
1010
// Each provider converts generic Messages and MCP tools into its native format.
11+
// Implementing a new provider requires only a new provider_X.go file.
1112
type LLMProvider interface {
1213
Chat(ctx context.Context, opts ChatOptions) (*LLMResponse, error)
14+
ListModels(ctx context.Context) ([]ModelInfo, error)
15+
Name() string
1316
}
1417

1518
// ChatOptions contains all parameters needed for an LLM API call.

internal/agent/provider_anthropic.go

Lines changed: 40 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,21 @@ import (
66

77
"github.com/anthropics/anthropic-sdk-go"
88
"github.com/anthropics/anthropic-sdk-go/option"
9+
"github.com/anthropics/anthropic-sdk-go/packages/param"
910
"github.com/modelcontextprotocol/go-sdk/mcp"
1011
)
1112

12-
type AnthropicProvider struct {
13+
// anthropicBase contains the shared Chat implementation for all providers
14+
// that use the Anthropic-compatible API (Anthropic cloud, Ollama, etc.).
15+
type anthropicBase struct {
1316
client anthropic.Client
1417
}
1518

16-
func NewClaudeProvider(apiKey string) *AnthropicProvider {
17-
return &AnthropicProvider{
18-
client: anthropic.NewClient(option.WithAPIKey(apiKey)),
19-
}
20-
}
21-
22-
func NewOllamaProvider(baseURL string) *AnthropicProvider {
23-
return &AnthropicProvider{
24-
client: anthropic.NewClient(
25-
option.WithBaseURL(baseURL),
26-
option.WithAPIKey("ollama"),
27-
),
28-
}
29-
}
30-
31-
func (p *AnthropicProvider) Chat(ctx context.Context, opts ChatOptions) (*LLMResponse, error) {
19+
func (b *anthropicBase) Chat(ctx context.Context, opts ChatOptions) (*LLMResponse, error) {
3220
messages := convertMessagesToAnthropic(opts.Messages)
3321
tools := convertToolsToAnthropic(opts.Tools)
3422

35-
response, err := p.client.Messages.New(ctx, anthropic.MessageNewParams{
23+
response, err := b.client.Messages.New(ctx, anthropic.MessageNewParams{
3624
Model: anthropic.Model(opts.Model),
3725
MaxTokens: opts.MaxTokens,
3826
System: []anthropic.TextBlockParam{{Type: "text", Text: opts.SystemPrompt}},
@@ -46,6 +34,40 @@ func (p *AnthropicProvider) Chat(ctx context.Context, opts ChatOptions) (*LLMRes
4634
return parseAnthropicResponse(response), nil
4735
}
4836

37+
type AnthropicProvider struct {
38+
anthropicBase
39+
}
40+
41+
func NewClaudeProvider(apiKey string) *AnthropicProvider {
42+
return &AnthropicProvider{
43+
anthropicBase: anthropicBase{
44+
client: anthropic.NewClient(option.WithAPIKey(apiKey)),
45+
},
46+
}
47+
}
48+
49+
func (p *AnthropicProvider) Name() string { return "anthropic" }
50+
51+
func (p *AnthropicProvider) ListModels(ctx context.Context) ([]ModelInfo, error) {
52+
page, err := p.client.Models.List(ctx, anthropic.ModelListParams{
53+
Limit: param.NewOpt[int64](1000),
54+
})
55+
if err != nil {
56+
return nil, fmt.Errorf("failed to list Anthropic models: %w", err)
57+
}
58+
59+
models := make([]ModelInfo, 0, len(page.Data))
60+
for _, m := range page.Data {
61+
models = append(models, ModelInfo{
62+
ID: m.ID,
63+
DisplayName: m.DisplayName,
64+
Provider: "anthropic",
65+
})
66+
}
67+
68+
return models, nil
69+
}
70+
4971
func convertMessagesToAnthropic(messages []Message) []anthropic.MessageParam {
5072
result := make([]anthropic.MessageParam, 0, len(messages))
5173

internal/agent/provider_ollama.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package agent
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"time"
10+
11+
"github.com/anthropics/anthropic-sdk-go"
12+
"github.com/anthropics/anthropic-sdk-go/option"
13+
)
14+
15+
type OllamaProvider struct {
16+
anthropicBase
17+
baseURL string
18+
}
19+
20+
func NewOllamaProvider(baseURL string) *OllamaProvider {
21+
return &OllamaProvider{
22+
anthropicBase: anthropicBase{
23+
client: anthropic.NewClient(
24+
option.WithBaseURL(baseURL),
25+
option.WithAPIKey("ollama"),
26+
),
27+
},
28+
baseURL: baseURL,
29+
}
30+
}
31+
32+
func (p *OllamaProvider) Name() string { return "ollama" }
33+
34+
func (p *OllamaProvider) ListModels(ctx context.Context) ([]ModelInfo, error) {
35+
httpClient := &http.Client{Timeout: 5 * time.Second}
36+
37+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.baseURL+"/api/tags", nil)
38+
if err != nil {
39+
return nil, fmt.Errorf("failed to create Ollama request: %w", err)
40+
}
41+
42+
resp, err := httpClient.Do(req)
43+
if err != nil {
44+
return nil, fmt.Errorf("failed to reach Ollama at %s: %w", p.baseURL, err)
45+
}
46+
defer resp.Body.Close()
47+
48+
body, err := io.ReadAll(resp.Body)
49+
if err != nil {
50+
return nil, fmt.Errorf("failed to read Ollama response: %w", err)
51+
}
52+
53+
if resp.StatusCode != http.StatusOK {
54+
return nil, fmt.Errorf("Ollama API returned status %d: %s", resp.StatusCode, string(body))
55+
}
56+
57+
var tagsResp struct {
58+
Models []struct {
59+
Name string `json:"name"`
60+
} `json:"models"`
61+
}
62+
if err := json.Unmarshal(body, &tagsResp); err != nil {
63+
return nil, fmt.Errorf("failed to parse Ollama response: %w", err)
64+
}
65+
66+
models := make([]ModelInfo, 0, len(tagsResp.Models))
67+
for _, m := range tagsResp.Models {
68+
models = append(models, ModelInfo{
69+
ID: m.Name,
70+
DisplayName: m.Name,
71+
Provider: "ollama",
72+
})
73+
}
74+
75+
return models, nil
76+
}

internal/agent/types.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,9 @@ type ToolResult struct {
2626
Output string
2727
IsError bool
2828
}
29+
30+
type ModelInfo struct {
31+
ID string `json:"id"`
32+
DisplayName string `json:"display_name"`
33+
Provider string `json:"provider"`
34+
}

internal/web/server.go

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ var templatesFS embed.FS
2121
// Server represents the web server for the chat interface
2222
type Server struct {
2323
agent *agent.Agent
24+
providers []agent.LLMProvider
2425
templates *template.Template
2526
eventChan chan SSEvent
2627
ctx context.Context
@@ -33,14 +34,15 @@ type SSEvent struct {
3334
}
3435

3536
// NewServer creates a new web server instance
36-
func NewServer(ctx context.Context, ag *agent.Agent) (*Server, error) {
37+
func NewServer(ctx context.Context, ag *agent.Agent, providers []agent.LLMProvider) (*Server, error) {
3738
tmpl, err := template.ParseFS(templatesFS, "templates/*.html")
3839
if err != nil {
3940
return nil, fmt.Errorf("failed to parse templates: %w", err)
4041
}
4142

4243
return &Server{
4344
agent: ag,
45+
providers: providers,
4446
templates: tmpl,
4547
eventChan: make(chan SSEvent, 100),
4648
ctx: ctx,
@@ -57,6 +59,9 @@ func (s *Server) Start(addr string) error {
5759
mux.HandleFunc("POST /tool/deny", s.handleToolDeny)
5860
mux.HandleFunc("GET /events", s.handleSSE)
5961
mux.HandleFunc("POST /reset", s.handleReset)
62+
mux.HandleFunc("GET /api/models", s.handleListModels)
63+
mux.HandleFunc("GET /api/model", s.handleGetModel)
64+
mux.HandleFunc("POST /api/model", s.handleSetModel)
6065

6166
server := &http.Server{
6267
Addr: addr,
@@ -260,6 +265,65 @@ func (s *Server) send(event SSEvent) {
260265
}
261266
}
262267

268+
func (s *Server) handleListModels(w http.ResponseWriter, r *http.Request) {
269+
var allModels []agent.ModelInfo
270+
271+
for _, p := range s.providers {
272+
models, err := p.ListModels(r.Context())
273+
if err != nil {
274+
log.Printf("[MODELS] Failed to list %s models: %v", p.Name(), err)
275+
continue
276+
}
277+
allModels = append(allModels, models...)
278+
}
279+
280+
w.Header().Set("Content-Type", "application/json")
281+
if err := json.NewEncoder(w).Encode(allModels); err != nil {
282+
log.Printf("[ERROR] Failed to encode models response: %v", err)
283+
}
284+
}
285+
286+
func (s *Server) handleGetModel(w http.ResponseWriter, r *http.Request) {
287+
w.Header().Set("Content-Type", "application/json")
288+
resp := map[string]string{"model": s.agent.GetModel()}
289+
if err := json.NewEncoder(w).Encode(resp); err != nil {
290+
log.Printf("[ERROR] Failed to encode model response: %v", err)
291+
}
292+
}
293+
294+
func (s *Server) handleSetModel(w http.ResponseWriter, r *http.Request) {
295+
r.Body = http.MaxBytesReader(w, r.Body, 1<<20)
296+
297+
var req struct {
298+
Provider string `json:"provider"`
299+
Model string `json:"model"`
300+
}
301+
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
302+
http.Error(w, "Invalid request body", http.StatusBadRequest)
303+
return
304+
}
305+
if req.Model == "" {
306+
http.Error(w, "Model is required", http.StatusBadRequest)
307+
return
308+
}
309+
310+
for _, p := range s.providers {
311+
if p.Name() == req.Provider {
312+
s.agent.SetModel(p, req.Model)
313+
log.Printf("[MODEL] Switched to %s/%s", req.Provider, req.Model)
314+
315+
w.Header().Set("Content-Type", "application/json")
316+
resp := map[string]string{"model": req.Model, "provider": req.Provider}
317+
if err := json.NewEncoder(w).Encode(resp); err != nil {
318+
log.Printf("[ERROR] Failed to encode set model response: %v", err)
319+
}
320+
return
321+
}
322+
}
323+
324+
http.Error(w, fmt.Sprintf("Unknown provider: %s", req.Provider), http.StatusBadRequest)
325+
}
326+
263327
func (s *Server) renderMessage(role, content, toolID string, tool *agent.ToolRequest) string {
264328
data := map[string]interface{}{
265329
"Role": role,

0 commit comments

Comments
 (0)