Skip to content

Commit bdd9d22

Browse files
tools: fix parsing issue when a tool name is a substring of another (ollama#11456)
Co-authored-by: frob <[email protected]>
1 parent 5fc38d0 commit bdd9d22

File tree

2 files changed

+310
-17
lines changed

2 files changed

+310
-17
lines changed

tools/tools.go

Lines changed: 68 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -115,21 +115,7 @@ func (p *Parser) findTag() (int, bool) {
115115
// parseToolCall finds the next complete tool call in the buffer
116116
// incrementing n and advancing the buffer.
117117
func (p *Parser) parseToolCall() *api.ToolCall {
118-
var tool *api.Tool
119-
var end int = len(p.buffer)
120-
var i int
121-
122-
// find tool name
123-
for _, t := range p.tools {
124-
n := t.Function.Name
125-
if i = bytes.Index(p.buffer, []byte(n)); i != -1 {
126-
if i+len(n) < end {
127-
tool = &t
128-
end = i + len(n)
129-
}
130-
}
131-
}
132-
118+
tool, end := findTool(p.tools, p.buffer)
133119
if tool == nil {
134120
return nil
135121
}
@@ -139,10 +125,10 @@ func (p *Parser) parseToolCall() *api.ToolCall {
139125
// parsing arguments before the tool name, which may be needed in the future
140126
args := map[string]any{}
141127
if len(tool.Function.Parameters.Properties) > 0 {
128+
var i int
142129
if args, i = findArguments(*tool, p.buffer[end:]); args == nil {
143130
return nil
144131
}
145-
146132
end += i
147133
}
148134

@@ -159,9 +145,74 @@ func (p *Parser) parseToolCall() *api.ToolCall {
159145
return tc
160146
}
161147

148+
// findTool finds the first tool name in the list that matches the
149+
// beginning of the buffer, returning nil if no tool is found
150+
// or if the buffer ends with a partial tool name since we need
151+
// to wait for more data to disambiguate.
152+
// The second return value is the end position of the tool name
153+
// if one is found, otherwise 0.
154+
func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
155+
if len(buf) == 0 {
156+
return nil, 0
157+
}
158+
159+
// check if buffer ends with a partial tool name
160+
// this prevents matching "get" when seeing "get_weather"
161+
var longest string
162+
for _, t := range tools {
163+
if len(t.Function.Name) > len(longest) {
164+
longest = t.Function.Name
165+
}
166+
}
167+
168+
// Only check up to longest characters from the end
169+
for i := 1; i <= min(len(buf), len(longest)); i++ {
170+
tail := buf[len(buf)-i:]
171+
for _, t := range tools {
172+
name := []byte(t.Function.Name)
173+
if len(tail) < len(name) && bytes.HasPrefix(name, tail) {
174+
return nil, 0
175+
}
176+
}
177+
}
178+
179+
// find first occurrence of the longest tool name
180+
var found *api.Tool
181+
start := -1
182+
end := -1
183+
184+
for i := range tools {
185+
name := []byte(tools[i].Function.Name)
186+
pos := bytes.Index(buf, name)
187+
if pos == -1 {
188+
continue
189+
}
190+
191+
// Skip if we have a better match already
192+
if start != -1 {
193+
if pos > start {
194+
continue
195+
}
196+
if pos == start && len(name) <= len(found.Function.Name) {
197+
continue
198+
}
199+
}
200+
201+
found = &tools[i]
202+
start = pos
203+
end = pos + len(name)
204+
}
205+
206+
if found != nil {
207+
return found, end
208+
}
209+
210+
return nil, 0
211+
}
212+
162213
// findArguments returns the first object that appears to be
163214
// arguments for the provided tool in the provided buffer,
164-
// returning nil if no arguments are found.
215+
// returning nil if no arguments are found and the end position
165216
// TODO (jmorganca): this does not support parsing omitted arguments
166217
// objects for functions that have all-optional parameters
167218
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but

tools/tools_test.go

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,81 @@ func TestParser(t *testing.T) {
112112
Description: "Say hello",
113113
},
114114
},
115+
{
116+
Type: "function",
117+
Function: api.ToolFunction{
118+
Name: "say_hello_world",
119+
Description: "Say hello world",
120+
},
121+
},
122+
{
123+
Type: "function",
124+
Function: api.ToolFunction{
125+
Name: "get_address",
126+
Description: "Get the address of a given location",
127+
Parameters: struct {
128+
Type string `json:"type"`
129+
Defs any `json:"$defs,omitempty"`
130+
Items any `json:"items,omitempty"`
131+
Required []string `json:"required"`
132+
Properties map[string]struct {
133+
Type api.PropertyType `json:"type"`
134+
Items any `json:"items,omitempty"`
135+
Description string `json:"description"`
136+
Enum []any `json:"enum,omitempty"`
137+
} `json:"properties"`
138+
}{
139+
Type: "object",
140+
Properties: map[string]struct {
141+
Type api.PropertyType `json:"type"`
142+
Items any `json:"items,omitempty"`
143+
Description string `json:"description"`
144+
Enum []any `json:"enum,omitempty"`
145+
}{
146+
"location": {
147+
Type: api.PropertyType{"string"},
148+
Description: "The location to get the address for",
149+
},
150+
},
151+
},
152+
},
153+
},
154+
{
155+
Type: "function",
156+
Function: api.ToolFunction{
157+
Name: "add",
158+
Description: "Add two numbers",
159+
Parameters: struct {
160+
Type string `json:"type"`
161+
Defs any `json:"$defs,omitempty"`
162+
Items any `json:"items,omitempty"`
163+
Required []string `json:"required"`
164+
Properties map[string]struct {
165+
Type api.PropertyType `json:"type"`
166+
Items any `json:"items,omitempty"`
167+
Description string `json:"description"`
168+
Enum []any `json:"enum,omitempty"`
169+
} `json:"properties"`
170+
}{
171+
Type: "object",
172+
Properties: map[string]struct {
173+
Type api.PropertyType `json:"type"`
174+
Items any `json:"items,omitempty"`
175+
Description string `json:"description"`
176+
Enum []any `json:"enum,omitempty"`
177+
}{
178+
"a": {
179+
Type: api.PropertyType{"string"},
180+
Description: "The first number to add",
181+
},
182+
"b": {
183+
Type: api.PropertyType{"string"},
184+
Description: "The second number to add",
185+
},
186+
},
187+
},
188+
},
189+
},
115190
}
116191

117192
tests := []struct {
@@ -629,6 +704,173 @@ func TestParser(t *testing.T) {
629704
},
630705
},
631706
},
707+
{
708+
name: "tool name with collision",
709+
inputs: []string{
710+
"<tool_call>",
711+
"{",
712+
"\"name\": \"say_hello",
713+
"_world\",",
714+
"}",
715+
"}",
716+
},
717+
content: "",
718+
tmpl: qwen,
719+
calls: []api.ToolCall{
720+
{
721+
Function: api.ToolCallFunction{
722+
Index: 0,
723+
Name: "say_hello_world",
724+
Arguments: api.ToolCallFunctionArguments{},
725+
},
726+
},
727+
},
728+
},
729+
{
730+
name: "tool name with collision multiple",
731+
inputs: []string{
732+
"<tool_call>",
733+
"{",
734+
"\"name\": \"say_hello",
735+
"_world\",",
736+
"}",
737+
"</tool_call>",
738+
"<tool_call>",
739+
"{",
740+
"\"name\": \"say_hello",
741+
"\",",
742+
"}",
743+
"</tool_call>",
744+
},
745+
content: "",
746+
tmpl: qwen,
747+
calls: []api.ToolCall{
748+
{
749+
Function: api.ToolCallFunction{
750+
Index: 0,
751+
Name: "say_hello_world",
752+
Arguments: api.ToolCallFunctionArguments{},
753+
},
754+
},
755+
{
756+
Function: api.ToolCallFunction{
757+
Index: 1,
758+
Name: "say_hello",
759+
Arguments: api.ToolCallFunctionArguments{},
760+
},
761+
},
762+
},
763+
},
764+
{
765+
name: "tool name with collision non streaming",
766+
inputs: []string{
767+
`<tool_call>{"name": "say_hello`,
768+
},
769+
content: "",
770+
tmpl: qwen,
771+
calls: nil,
772+
},
773+
{
774+
name: "tool name with collision non streaming multiple",
775+
inputs: []string{
776+
`<tool_call>{"name": "say_hello"}</tool_call><tool_call>{"name": "say_hello_world"}`,
777+
},
778+
content: "",
779+
tmpl: qwen,
780+
calls: []api.ToolCall{
781+
{
782+
Function: api.ToolCallFunction{
783+
Index: 0,
784+
Name: "say_hello",
785+
Arguments: api.ToolCallFunctionArguments{},
786+
},
787+
},
788+
{
789+
Function: api.ToolCallFunction{
790+
Index: 1,
791+
Name: "say_hello_world",
792+
Arguments: api.ToolCallFunctionArguments{},
793+
},
794+
},
795+
},
796+
},
797+
{
798+
name: "tool name with collision non streaming shorter",
799+
inputs: []string{
800+
`<tool_call>{"name": "say_hello"}</tool_call>`,
801+
},
802+
content: "",
803+
tmpl: qwen,
804+
calls: []api.ToolCall{
805+
{
806+
Function: api.ToolCallFunction{
807+
Index: 0,
808+
Name: "say_hello",
809+
Arguments: api.ToolCallFunctionArguments{},
810+
},
811+
},
812+
},
813+
},
814+
{
815+
name: "tool name with collision non streaming longer",
816+
inputs: []string{
817+
`<tool_call>{"name": "say_hello_world"}</tool_call>`,
818+
},
819+
content: "",
820+
tmpl: qwen,
821+
calls: []api.ToolCall{
822+
{
823+
Function: api.ToolCallFunction{
824+
Index: 0,
825+
Name: "say_hello_world",
826+
Arguments: api.ToolCallFunctionArguments{},
827+
},
828+
},
829+
},
830+
},
831+
{
832+
name: "tool name with substring of another",
833+
inputs: []string{
834+
"{",
835+
"\"name\": \"get_address\",",
836+
"\"arguments\": {",
837+
"\"location\": \"London\"",
838+
"}",
839+
"}",
840+
},
841+
content: "",
842+
tmpl: json,
843+
calls: []api.ToolCall{
844+
{
845+
Function: api.ToolCallFunction{
846+
Index: 0,
847+
Name: "get_address",
848+
Arguments: api.ToolCallFunctionArguments{
849+
"location": "London",
850+
},
851+
},
852+
},
853+
},
854+
},
855+
{
856+
name: "tool name with substring of another",
857+
inputs: []string{
858+
`<tool_call>{"name": "get_address", "arguments": {"location": "London"}}</tool_call>`,
859+
},
860+
content: "",
861+
tmpl: qwen,
862+
calls: []api.ToolCall{
863+
{
864+
Function: api.ToolCallFunction{
865+
Index: 0,
866+
Name: "get_address",
867+
Arguments: api.ToolCallFunctionArguments{
868+
"location": "London",
869+
},
870+
},
871+
},
872+
},
873+
},
632874
}
633875

634876
for _, tt := range tests {

0 commit comments

Comments
 (0)