@@ -3,6 +3,7 @@ package agent
33import (
44 "context"
55 "fmt"
6+ "strings"
67 "time"
78
89 "go.uber.org/zap"
@@ -123,22 +124,39 @@ func (a *Agent) RegisterSkillTool(skillsManager *skills.Manager) {
123124// Chat processes a user message and returns the agent's response.
124125// It handles tool calls and iterates until the agent produces a final response.
125126func (a * Agent ) Chat (ctx context.Context , sess SessionInterface , userMessage string ) (string , error ) {
126- return a .chatWithModel (ctx , sess , userMessage , a .config .Agents .Defaults .Model )
127+ return a .chatWithProviderModel (ctx , sess , userMessage , "" , a .config .Agents .Defaults .Model , nil )
127128}
128129
129130// ChatWithModel processes a message using a specific model override.
130131func (a * Agent ) ChatWithModel (ctx context.Context , sess SessionInterface , userMessage , model string ) (string , error ) {
131- return a .chatWithModel (ctx , sess , userMessage , model )
132+ return a .chatWithProviderModel (ctx , sess , userMessage , "" , model , nil )
132133}
133134
134- func (a * Agent ) chatWithModel (ctx context.Context , sess SessionInterface , userMessage , model string ) (string , error ) {
135+ // ChatWithProviderModel processes a message using provider/model overrides.
136+ func (a * Agent ) ChatWithProviderModel (ctx context.Context , sess SessionInterface , userMessage , provider , model string ) (string , error ) {
137+ return a .chatWithProviderModel (ctx , sess , userMessage , provider , model , nil )
138+ }
139+
140+ // ChatWithProviderModelAndFallback processes a message using provider/model/fallback overrides.
141+ func (a * Agent ) ChatWithProviderModelAndFallback (ctx context.Context , sess SessionInterface , userMessage , provider , model string , fallback []string ) (string , error ) {
142+ return a .chatWithProviderModel (ctx , sess , userMessage , provider , model , fallback )
143+ }
144+
145+ func (a * Agent ) chatWithProviderModel (ctx context.Context , sess SessionInterface , userMessage , provider , model string , fallback []string ) (string , error ) {
135146 a .logger .Info ("Processing chat message" ,
136147 zap .String ("message" , truncate (userMessage , 100 )),
137148 )
138149 if model == "" {
139150 model = a .config .Agents .Defaults .Model
140151 }
141152
153+ providerOrder , err := a .buildProviderOrder (provider , fallback )
154+ if err != nil {
155+ return "" , err
156+ }
157+ primaryProvider := providerOrder [0 ]
158+ clientCache := make (map [string ]* providers.Client )
159+
142160 // Build initial messages with session history
143161 history := sess .GetMessages () // Get messages from session
144162 messages := a .context .BuildMessages (history , userMessage )
@@ -176,13 +194,15 @@ func (a *Agent) chatWithModel(ctx context.Context, sess SessionInterface, userMe
176194 }
177195 }
178196
179- // Call LLM
180- resp , err := a .client . Chat (ctx , req )
197+ // Call LLM with provider fallback.
198+ resp , providerUsed , modelUsed , err := a .callLLMWithFallback (ctx , req , primaryProvider , providerOrder , model , clientCache )
181199 if err != nil {
182200 return "" , fmt .Errorf ("LLM call failed: %w" , err )
183201 }
184202
185203 a .logger .Debug ("LLM response" ,
204+ zap .String ("provider" , providerUsed ),
205+ zap .String ("model" , modelUsed ),
186206 zap .String ("content" , truncate (resp .Content , 100 )),
187207 zap .Int ("tool_calls" , len (resp .ToolCalls )),
188208 zap .String ("finish_reason" , resp .FinishReason ),
@@ -236,6 +256,159 @@ func (a *Agent) chatWithModel(ctx context.Context, sess SessionInterface, userMe
236256 return "" , fmt .Errorf ("max iterations (%d) reached without final response" , a .maxIterations )
237257}
238258
259+ func (a * Agent ) newClientForProvider (providerName , model string ) (* providers.Client , error ) {
260+ providerCfg := a .config .GetProviderConfig (providerName )
261+ if providerCfg == nil {
262+ return nil , fmt .Errorf ("provider not found: %s" , providerName )
263+ }
264+
265+ providerKind := strings .TrimSpace (providerCfg .ProviderKind )
266+ if providerKind == "" {
267+ providerKind = providerName
268+ }
269+
270+ client , err := providers .NewClient (providerKind , & providers.RelayInfo {
271+ ProviderName : providerName ,
272+ APIKey : providerCfg .APIKey ,
273+ APIBase : providerCfg .APIBase ,
274+ Model : model ,
275+ Proxy : providerCfg .Proxy ,
276+ Timeout : providerCfg .GetTimeout (),
277+ })
278+ if err != nil {
279+ return nil , fmt .Errorf ("create provider client for %s: %w" , providerName , err )
280+ }
281+
282+ return client , nil
283+ }
284+
285+ func (a * Agent ) buildProviderOrder (provider string , fallback []string ) ([]string , error ) {
286+ primary := strings .TrimSpace (provider )
287+ if primary == "" {
288+ primary = strings .TrimSpace (a .config .Agents .Defaults .Provider )
289+ }
290+
291+ fallbackOrder := fallback
292+ if len (fallbackOrder ) == 0 {
293+ fallbackOrder = a .config .Agents .Defaults .Fallback
294+ }
295+
296+ seen := make (map [string ]struct {})
297+ order := make ([]string , 0 , 1 + len (fallbackOrder ))
298+
299+ addProvider := func (name string ) {
300+ trimmed := strings .TrimSpace (name )
301+ if trimmed == "" {
302+ return
303+ }
304+ if _ , ok := seen [trimmed ]; ok {
305+ return
306+ }
307+ seen [trimmed ] = struct {}{}
308+ order = append (order , trimmed )
309+ }
310+
311+ addProvider (primary )
312+ for _ , name := range fallbackOrder {
313+ addProvider (name )
314+ }
315+
316+ if len (order ) == 0 && len (a .config .Providers ) > 0 {
317+ addProvider (a .config .Providers [0 ].Name )
318+ }
319+
320+ if len (order ) == 0 {
321+ return nil , fmt .Errorf ("no providers configured" )
322+ }
323+
324+ return order , nil
325+ }
326+
327+ func (a * Agent ) callLLMWithFallback (
328+ ctx context.Context ,
329+ req * providers.UnifiedRequest ,
330+ primaryProvider string ,
331+ providerOrder []string ,
332+ requestedModel string ,
333+ clientCache map [string ]* providers.Client ,
334+ ) (* providers.UnifiedResponse , string , string , error ) {
335+ var lastErr error
336+
337+ for _ , providerName := range providerOrder {
338+ model := a .resolveModelForProvider (providerName , primaryProvider , requestedModel )
339+
340+ client , err := a .getProviderClient (providerName , model , clientCache )
341+ if err != nil {
342+ lastErr = err
343+ a .logger .Warn ("Provider unavailable" , zap .String ("provider" , providerName ), zap .Error (err ))
344+ continue
345+ }
346+
347+ reqCopy := * req
348+ reqCopy .Model = model
349+
350+ resp , err := client .Chat (ctx , & reqCopy )
351+ if err != nil {
352+ lastErr = err
353+ a .logger .Warn ("Provider request failed" , zap .String ("provider" , providerName ), zap .String ("model" , model ), zap .Error (err ))
354+ continue
355+ }
356+
357+ return resp , providerName , model , nil
358+ }
359+
360+ if lastErr == nil {
361+ lastErr = fmt .Errorf ("no provider attempt made" )
362+ }
363+ return nil , "" , "" , lastErr
364+ }
365+
366+ func (a * Agent ) getProviderClient (providerName , model string , cache map [string ]* providers.Client ) (* providers.Client , error ) {
367+ key := providerName + "::" + model
368+ if client , ok := cache [key ]; ok {
369+ return client , nil
370+ }
371+
372+ client , err := a .newClientForProvider (providerName , model )
373+ if err != nil {
374+ return nil , err
375+ }
376+ cache [key ] = client
377+ return client , nil
378+ }
379+
380+ func (a * Agent ) resolveModelForProvider (providerName , primaryProvider , requestedModel string ) string {
381+ model := strings .TrimSpace (requestedModel )
382+ if model == "" {
383+ model = strings .TrimSpace (a .config .Agents .Defaults .Model )
384+ }
385+ if providerName == primaryProvider {
386+ return model
387+ }
388+
389+ providerCfg := a .config .GetProviderConfig (providerName )
390+ if providerCfg == nil {
391+ return model
392+ }
393+
394+ // If this provider declares no model list, keep caller's model.
395+ if len (providerCfg .Models ) == 0 {
396+ return model
397+ }
398+
399+ for _ , candidate := range providerCfg .Models {
400+ if strings .TrimSpace (candidate ) == model {
401+ return model
402+ }
403+ }
404+
405+ if fallbackModel := strings .TrimSpace (providerCfg .GetDefaultModel ()); fallbackModel != "" {
406+ return fallbackModel
407+ }
408+
409+ return model
410+ }
411+
239412// executeToolCall executes a single tool call with approval checking.
240413func (a * Agent ) executeToolCall (ctx context.Context , toolCall providers.UnifiedToolCall ) (string , error ) {
241414 a .logger .Info ("Executing tool" ,
0 commit comments