Skip to content

Commit 2f1d10e

Browse files
committed
feat: improve AgenticToolChoice (#684)
1 parent d358aa3 commit 2f1d10e

File tree

4 files changed

+65
-14
lines changed

4 files changed

+65
-14
lines changed

components/model/option.go

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ type Options struct {
2828
TopP *float32
2929
// Tools is a list of tools the model may call.
3030
Tools []*schema.ToolInfo
31-
// ToolChoice controls which tool is called by the model.
32-
ToolChoice *schema.ToolChoice
3331

34-
// Options only for chat model.
32+
// Options only available for chat model.
3533

34+
// ToolChoice controls which tool is called by the model.
35+
ToolChoice *schema.ToolChoice
3636
// MaxTokens is the max number of tokens, if reached the max tokens, the model will stop generating, and mostly return an finish reason of "length".
3737
MaxTokens *int
3838
// AllowedToolNames specifies a list of tool names that the model is allowed to call.
@@ -41,10 +41,10 @@ type Options struct {
4141
// Stop is the stop words for the model, which controls the stopping condition of the model.
4242
Stop []string
4343

44-
// Options only for agentic model.
44+
// Options only available for agentic model.
4545

46-
// AllowedTools is a list of allowed tools the model may call.
47-
AllowedTools []*schema.AllowedTool
46+
// AgenticToolChoice controls how the agentic model calls tools.
47+
AgenticToolChoice *schema.AgenticToolChoice
4848
}
4949

5050
// Option is the call option for ChatModel component.
@@ -127,11 +127,10 @@ func WithToolChoice(toolChoice schema.ToolChoice, allowedToolNames ...string) Op
127127

128128
// WithAgenticToolChoice is the option to set tool choice for the agentic model.
129129
// Only available for AgenticModel.
130-
func WithAgenticToolChoice(toolChoice schema.ToolChoice, allowedTools ...*schema.AllowedTool) Option {
130+
func WithAgenticToolChoice(toolChoice *schema.AgenticToolChoice) Option {
131131
return Option{
132132
apply: func(opts *Options) {
133-
opts.ToolChoice = &toolChoice
134-
opts.AllowedTools = allowedTools
133+
opts.AgenticToolChoice = toolChoice
135134
},
136135
}
137136
}

components/model/option_test.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,11 +92,18 @@ func TestOptions(t *testing.T) {
9292
)
9393
opts := GetCommonOptions(
9494
nil,
95-
WithAgenticToolChoice(toolChoice, allowedTools...),
95+
WithAgenticToolChoice(&schema.AgenticToolChoice{
96+
Type: toolChoice,
97+
Forced: &schema.AgenticForcedToolChoice{
98+
Tools: allowedTools,
99+
},
100+
}),
96101
)
97102

98-
convey.So(opts.ToolChoice, convey.ShouldResemble, &toolChoice)
99-
convey.So(opts.AllowedTools, convey.ShouldResemble, allowedTools)
103+
convey.So(opts.AgenticToolChoice, convey.ShouldNotBeNil)
104+
convey.So(opts.AgenticToolChoice.Type, convey.ShouldEqual, toolChoice)
105+
convey.So(opts.AgenticToolChoice.Forced, convey.ShouldNotBeNil)
106+
convey.So(opts.AgenticToolChoice.Forced.Tools, convey.ShouldResemble, allowedTools)
100107
})
101108
}
102109

compose/workflow.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,18 +89,36 @@ func (wf *Workflow[I, O]) AddChatModelNode(key string, chatModel model.BaseChatM
8989
return wf.initNode(key)
9090
}
9191

92+
// AddAgenticModelNode adds an agentic model node and returns it.
93+
func (wf *Workflow[I, O]) AddAgenticModelNode(key string, agenticModel model.AgenticModel, opts ...GraphAddNodeOpt) *WorkflowNode {
94+
_ = wf.g.AddAgenticModelNode(key, agenticModel, opts...)
95+
return wf.initNode(key)
96+
}
97+
9298
// AddChatTemplateNode adds a chat template node and returns it.
9399
func (wf *Workflow[I, O]) AddChatTemplateNode(key string, chatTemplate prompt.ChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode {
94100
_ = wf.g.AddChatTemplateNode(key, chatTemplate, opts...)
95101
return wf.initNode(key)
96102
}
97103

104+
// AddAgenticChatTemplateNode adds an agentic chat template node and returns it.
105+
func (wf *Workflow[I, O]) AddAgenticChatTemplateNode(key string, chatTemplate prompt.AgenticChatTemplate, opts ...GraphAddNodeOpt) *WorkflowNode {
106+
_ = wf.g.AddAgenticChatTemplateNode(key, chatTemplate, opts...)
107+
return wf.initNode(key)
108+
}
109+
98110
// AddToolsNode adds a tools node and returns it.
99111
func (wf *Workflow[I, O]) AddToolsNode(key string, tools *ToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode {
100112
_ = wf.g.AddToolsNode(key, tools, opts...)
101113
return wf.initNode(key)
102114
}
103115

116+
// AddAgenticToolsNode adds an agentic tools node and returns it.
117+
func (wf *Workflow[I, O]) AddAgenticToolsNode(key string, tools *AgenticToolsNode, opts ...GraphAddNodeOpt) *WorkflowNode {
118+
_ = wf.g.AddAgenticToolsNode(key, tools, opts...)
119+
return wf.initNode(key)
120+
}
121+
104122
// AddRetrieverNode adds a retriever node and returns it.
105123
func (wf *Workflow[I, O]) AddRetrieverNode(key string, retriever retriever.Retriever, opts ...GraphAddNodeOpt) *WorkflowNode {
106124
_ = wf.g.AddRetrieverNode(key, retriever, opts...)

schema/tool.go

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,31 @@ const (
5555
ToolChoiceForced ToolChoice = "forced"
5656
)
5757

58+
type AgenticToolChoice struct {
59+
// Type is the tool choice mode.
60+
Type ToolChoice
61+
// Allowed optionally specifies the list of tools that the model is permitted to call.
62+
Allowed *AgenticAllowedToolChoice
63+
// Forced optionally specifies the list of tools that the model is required to call.
64+
Forced *AgenticForcedToolChoice
65+
}
66+
67+
// AgenticAllowedToolChoice specifies a list of allowed tools for the model.
68+
type AgenticAllowedToolChoice struct {
69+
// Tools is the list of allowed tools for the model to call.
70+
// Optional.
71+
Tools []*AllowedTool
72+
}
73+
74+
// AgenticForcedToolChoice specifies a list of tools that the model must call.
75+
type AgenticForcedToolChoice struct {
76+
// Tools is the list of tools that the model must call.
77+
// Optional.
78+
Tools []*AllowedTool
79+
}
80+
81+
// AllowedTool represents a tool that the model is allowed or forced to call.
82+
// Exactly one of FunctionToolName, MCPTool, or ServerTool must be specified.
5883
type AllowedTool struct {
5984
// FunctionToolName is the name of the function tool.
6085
FunctionToolName string
@@ -64,15 +89,17 @@ type AllowedTool struct {
6489
ServerTool *AllowedServerTool
6590
}
6691

92+
// AllowedMCPTool contains the information for identifying an MCP tool.
6793
type AllowedMCPTool struct {
6894
// ServerLabel is the label of the MCP server.
6995
ServerLabel string
70-
// The name of the MCP tool.
96+
// Name is the name of the MCP tool.
7197
Name string
7298
}
7399

100+
// AllowedServerTool contains the information for identifying a server tool.
74101
type AllowedServerTool struct {
75-
// The name of the server tool.
102+
// Name is the name of the server tool.
76103
Name string
77104
}
78105

0 commit comments

Comments
 (0)