Skip to content

Commit c5c7311

Browse files
committed
feat(webui): support realtime provider/fallback routing with persistence
1 parent 6424022 commit c5c7311

File tree

8 files changed

+391
-12
lines changed

8 files changed

+391
-12
lines changed

pkg/agent/agent.go

Lines changed: 178 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package agent
33
import (
44
"context"
55
"fmt"
6+
"strings"
67
"time"
78

89
"go.uber.org/zap"
@@ -123,22 +124,39 @@ func (a *Agent) RegisterSkillTool(skillsManager *skills.Manager) {
123124
// Chat processes a user message and returns the agent's response.
124125
// It handles tool calls and iterates until the agent produces a final response.
125126
func (a *Agent) Chat(ctx context.Context, sess SessionInterface, userMessage string) (string, error) {
126-
return a.chatWithModel(ctx, sess, userMessage, a.config.Agents.Defaults.Model)
127+
return a.chatWithProviderModel(ctx, sess, userMessage, "", a.config.Agents.Defaults.Model, nil)
127128
}
128129

129130
// ChatWithModel processes a message using a specific model override.
130131
func (a *Agent) ChatWithModel(ctx context.Context, sess SessionInterface, userMessage, model string) (string, error) {
131-
return a.chatWithModel(ctx, sess, userMessage, model)
132+
return a.chatWithProviderModel(ctx, sess, userMessage, "", model, nil)
132133
}
133134

134-
func (a *Agent) chatWithModel(ctx context.Context, sess SessionInterface, userMessage, model string) (string, error) {
135+
// ChatWithProviderModel processes a message using provider/model overrides.
136+
func (a *Agent) ChatWithProviderModel(ctx context.Context, sess SessionInterface, userMessage, provider, model string) (string, error) {
137+
return a.chatWithProviderModel(ctx, sess, userMessage, provider, model, nil)
138+
}
139+
140+
// ChatWithProviderModelAndFallback processes a message using provider/model/fallback overrides.
141+
func (a *Agent) ChatWithProviderModelAndFallback(ctx context.Context, sess SessionInterface, userMessage, provider, model string, fallback []string) (string, error) {
142+
return a.chatWithProviderModel(ctx, sess, userMessage, provider, model, fallback)
143+
}
144+
145+
func (a *Agent) chatWithProviderModel(ctx context.Context, sess SessionInterface, userMessage, provider, model string, fallback []string) (string, error) {
135146
a.logger.Info("Processing chat message",
136147
zap.String("message", truncate(userMessage, 100)),
137148
)
138149
if model == "" {
139150
model = a.config.Agents.Defaults.Model
140151
}
141152

153+
providerOrder, err := a.buildProviderOrder(provider, fallback)
154+
if err != nil {
155+
return "", err
156+
}
157+
primaryProvider := providerOrder[0]
158+
clientCache := make(map[string]*providers.Client)
159+
142160
// Build initial messages with session history
143161
history := sess.GetMessages() // Get messages from session
144162
messages := a.context.BuildMessages(history, userMessage)
@@ -176,13 +194,15 @@ func (a *Agent) chatWithModel(ctx context.Context, sess SessionInterface, userMe
176194
}
177195
}
178196

179-
// Call LLM
180-
resp, err := a.client.Chat(ctx, req)
197+
// Call LLM with provider fallback.
198+
resp, providerUsed, modelUsed, err := a.callLLMWithFallback(ctx, req, primaryProvider, providerOrder, model, clientCache)
181199
if err != nil {
182200
return "", fmt.Errorf("LLM call failed: %w", err)
183201
}
184202

185203
a.logger.Debug("LLM response",
204+
zap.String("provider", providerUsed),
205+
zap.String("model", modelUsed),
186206
zap.String("content", truncate(resp.Content, 100)),
187207
zap.Int("tool_calls", len(resp.ToolCalls)),
188208
zap.String("finish_reason", resp.FinishReason),
@@ -236,6 +256,159 @@ func (a *Agent) chatWithModel(ctx context.Context, sess SessionInterface, userMe
236256
return "", fmt.Errorf("max iterations (%d) reached without final response", a.maxIterations)
237257
}
238258

259+
func (a *Agent) newClientForProvider(providerName, model string) (*providers.Client, error) {
260+
providerCfg := a.config.GetProviderConfig(providerName)
261+
if providerCfg == nil {
262+
return nil, fmt.Errorf("provider not found: %s", providerName)
263+
}
264+
265+
providerKind := strings.TrimSpace(providerCfg.ProviderKind)
266+
if providerKind == "" {
267+
providerKind = providerName
268+
}
269+
270+
client, err := providers.NewClient(providerKind, &providers.RelayInfo{
271+
ProviderName: providerName,
272+
APIKey: providerCfg.APIKey,
273+
APIBase: providerCfg.APIBase,
274+
Model: model,
275+
Proxy: providerCfg.Proxy,
276+
Timeout: providerCfg.GetTimeout(),
277+
})
278+
if err != nil {
279+
return nil, fmt.Errorf("create provider client for %s: %w", providerName, err)
280+
}
281+
282+
return client, nil
283+
}
284+
285+
func (a *Agent) buildProviderOrder(provider string, fallback []string) ([]string, error) {
286+
primary := strings.TrimSpace(provider)
287+
if primary == "" {
288+
primary = strings.TrimSpace(a.config.Agents.Defaults.Provider)
289+
}
290+
291+
fallbackOrder := fallback
292+
if len(fallbackOrder) == 0 {
293+
fallbackOrder = a.config.Agents.Defaults.Fallback
294+
}
295+
296+
seen := make(map[string]struct{})
297+
order := make([]string, 0, 1+len(fallbackOrder))
298+
299+
addProvider := func(name string) {
300+
trimmed := strings.TrimSpace(name)
301+
if trimmed == "" {
302+
return
303+
}
304+
if _, ok := seen[trimmed]; ok {
305+
return
306+
}
307+
seen[trimmed] = struct{}{}
308+
order = append(order, trimmed)
309+
}
310+
311+
addProvider(primary)
312+
for _, name := range fallbackOrder {
313+
addProvider(name)
314+
}
315+
316+
if len(order) == 0 && len(a.config.Providers) > 0 {
317+
addProvider(a.config.Providers[0].Name)
318+
}
319+
320+
if len(order) == 0 {
321+
return nil, fmt.Errorf("no providers configured")
322+
}
323+
324+
return order, nil
325+
}
326+
327+
func (a *Agent) callLLMWithFallback(
328+
ctx context.Context,
329+
req *providers.UnifiedRequest,
330+
primaryProvider string,
331+
providerOrder []string,
332+
requestedModel string,
333+
clientCache map[string]*providers.Client,
334+
) (*providers.UnifiedResponse, string, string, error) {
335+
var lastErr error
336+
337+
for _, providerName := range providerOrder {
338+
model := a.resolveModelForProvider(providerName, primaryProvider, requestedModel)
339+
340+
client, err := a.getProviderClient(providerName, model, clientCache)
341+
if err != nil {
342+
lastErr = err
343+
a.logger.Warn("Provider unavailable", zap.String("provider", providerName), zap.Error(err))
344+
continue
345+
}
346+
347+
reqCopy := *req
348+
reqCopy.Model = model
349+
350+
resp, err := client.Chat(ctx, &reqCopy)
351+
if err != nil {
352+
lastErr = err
353+
a.logger.Warn("Provider request failed", zap.String("provider", providerName), zap.String("model", model), zap.Error(err))
354+
continue
355+
}
356+
357+
return resp, providerName, model, nil
358+
}
359+
360+
if lastErr == nil {
361+
lastErr = fmt.Errorf("no provider attempt made")
362+
}
363+
return nil, "", "", lastErr
364+
}
365+
366+
func (a *Agent) getProviderClient(providerName, model string, cache map[string]*providers.Client) (*providers.Client, error) {
367+
key := providerName + "::" + model
368+
if client, ok := cache[key]; ok {
369+
return client, nil
370+
}
371+
372+
client, err := a.newClientForProvider(providerName, model)
373+
if err != nil {
374+
return nil, err
375+
}
376+
cache[key] = client
377+
return client, nil
378+
}
379+
380+
func (a *Agent) resolveModelForProvider(providerName, primaryProvider, requestedModel string) string {
381+
model := strings.TrimSpace(requestedModel)
382+
if model == "" {
383+
model = strings.TrimSpace(a.config.Agents.Defaults.Model)
384+
}
385+
if providerName == primaryProvider {
386+
return model
387+
}
388+
389+
providerCfg := a.config.GetProviderConfig(providerName)
390+
if providerCfg == nil {
391+
return model
392+
}
393+
394+
// If this provider declares no model list, keep caller's model.
395+
if len(providerCfg.Models) == 0 {
396+
return model
397+
}
398+
399+
for _, candidate := range providerCfg.Models {
400+
if strings.TrimSpace(candidate) == model {
401+
return model
402+
}
403+
}
404+
405+
if fallbackModel := strings.TrimSpace(providerCfg.GetDefaultModel()); fallbackModel != "" {
406+
return fallbackModel
407+
}
408+
409+
return model
410+
}
411+
239412
// executeToolCall executes a single tool call with approval checking.
240413
func (a *Agent) executeToolCall(ctx context.Context, toolCall providers.UnifiedToolCall) (string, error) {
241414
a.logger.Info("Executing tool",

pkg/agent/agent_test.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package agent
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
7+
"nekobot/pkg/config"
8+
)
9+
10+
func TestBuildProviderOrder_UsesOverrideAndFallback(t *testing.T) {
11+
cfg := config.DefaultConfig()
12+
cfg.Agents.Defaults.Provider = "anthropic"
13+
cfg.Agents.Defaults.Fallback = []string{"openai", "ollama"}
14+
cfg.Providers = []config.ProviderProfile{
15+
{Name: "anthropic", ProviderKind: "anthropic"},
16+
{Name: "openai", ProviderKind: "openai"},
17+
{Name: "ollama", ProviderKind: "openai"},
18+
}
19+
20+
ag := &Agent{config: cfg}
21+
22+
got, err := ag.buildProviderOrder("openai", []string{"ollama", "openai", "anthropic"})
23+
if err != nil {
24+
t.Fatalf("buildProviderOrder failed: %v", err)
25+
}
26+
27+
want := []string{"openai", "ollama", "anthropic"}
28+
if !reflect.DeepEqual(got, want) {
29+
t.Fatalf("expected provider order %v, got %v", want, got)
30+
}
31+
}
32+
33+
func TestBuildProviderOrder_UsesConfigDefaultsWhenRequestFallbackEmpty(t *testing.T) {
34+
cfg := config.DefaultConfig()
35+
cfg.Agents.Defaults.Provider = "anthropic"
36+
cfg.Agents.Defaults.Fallback = []string{"openai"}
37+
cfg.Providers = []config.ProviderProfile{
38+
{Name: "anthropic", ProviderKind: "anthropic"},
39+
{Name: "openai", ProviderKind: "openai"},
40+
}
41+
42+
ag := &Agent{config: cfg}
43+
44+
got, err := ag.buildProviderOrder("", nil)
45+
if err != nil {
46+
t.Fatalf("buildProviderOrder failed: %v", err)
47+
}
48+
49+
want := []string{"anthropic", "openai"}
50+
if !reflect.DeepEqual(got, want) {
51+
t.Fatalf("expected provider order %v, got %v", want, got)
52+
}
53+
}
54+
55+
func TestResolveModelForProvider_FallsBackToProviderDefaultModel(t *testing.T) {
56+
cfg := config.DefaultConfig()
57+
cfg.Agents.Defaults.Model = "claude-sonnet-4-5-20250929"
58+
cfg.Providers = []config.ProviderProfile{
59+
{
60+
Name: "anthropic",
61+
ProviderKind: "anthropic",
62+
Models: []string{"claude-sonnet-4-5-20250929"},
63+
DefaultModel: "claude-sonnet-4-5-20250929",
64+
},
65+
{
66+
Name: "openai",
67+
ProviderKind: "openai",
68+
Models: []string{"gpt-4o-mini"},
69+
DefaultModel: "gpt-4o-mini",
70+
},
71+
}
72+
73+
ag := &Agent{config: cfg}
74+
75+
got := ag.resolveModelForProvider("openai", "anthropic", "claude-sonnet-4-5-20250929")
76+
want := "gpt-4o-mini"
77+
if got != want {
78+
t.Fatalf("expected model %q, got %q", want, got)
79+
}
80+
}

0 commit comments

Comments
 (0)