Skip to content

Commit 78958d8

Browse files
committed
added tools to middleware interface
1 parent 5cfc80d commit 78958d8

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

go/ai/generate.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,14 @@ func Generate(ctx context.Context, r api.Registry, opts ...GenerateOption) (*Mod
500500
return nil, err
501501
}
502502

503+
// Collect tools provided by middleware.
504+
for _, mw := range genOpts.Use {
505+
for _, t := range mw.Tools() {
506+
dynamicTools = append(dynamicTools, t)
507+
toolNames = append(toolNames, t.Name())
508+
}
509+
}
510+
503511
if len(dynamicTools) > 0 {
504512
if !r.IsChild() {
505513
r = r.NewChild()

go/ai/middleware.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ type Middleware interface {
4040
Model(ctx context.Context, state *ModelState, next ModelNext) (*ModelResponse, error)
4141
// Tool wraps each tool execution.
4242
Tool(ctx context.Context, state *ToolState, next ToolNext) (*ToolResponse, error)
43+
// Tools returns additional tools to make available during generation.
44+
// These tools are dynamically registered when the middleware is used via [WithUse].
45+
Tools() []Tool
4346
}
4447

4548
// GenerateState holds state for the Generate hook.
@@ -93,6 +96,8 @@ func (b *BaseMiddleware) Tool(ctx context.Context, state *ToolState, next ToolNe
9396
return next(ctx, state)
9497
}
9598

99+
func (b *BaseMiddleware) Tools() []Tool { return nil }
100+
96101
// Register registers the descriptor with the registry.
97102
func (d *MiddlewareDesc) Register(r api.Registry) {
98103
r.RegisterValue("/middleware/"+d.Name, d)

0 commit comments

Comments
 (0)