Commit 8d675f4

filintod and yaron2 authored
Conversation Tool Calling fix on echo and other minor fixes (#3930)
Signed-off-by: Filinto Duran <[email protected]>
Co-authored-by: Yaron Schneider <[email protected]>
1 parent 20d7345 commit 8d675f4

6 files changed: +174 -170 lines changed


conversation/deepseek/deepseek.go

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,10 @@ func (d *Deepseek) Init(ctx context.Context, meta conversation.Metadata) error {
 		model = md.Model
 	}

+	if md.Endpoint == "" {
+		md.Endpoint = defaultEndpoint
+	}
+
 	options := []openai.Option{
 		openai.WithModel(model),
 		openai.WithToken(md.Key),
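The hunk above only adds a guard in Init, but the fallback pattern is worth seeing in isolation. Below is a minimal, self-contained sketch of the same idea; the defaultEndpoint value here is a placeholder for illustration, not the constant the component actually defines, and how the endpoint is later wired into the OpenAI options is not shown in this diff.

package main

import "fmt"

// defaultEndpoint is a hypothetical placeholder for this sketch; the real
// constant lives in conversation/deepseek and is not part of this diff.
const defaultEndpoint = "https://example.invalid/deepseek"

// resolveEndpoint mirrors the guard added to Init: an empty metadata value
// falls back to the package default, while an explicit value is kept.
func resolveEndpoint(endpoint string) string {
	if endpoint == "" {
		return defaultEndpoint
	}
	return endpoint
}

func main() {
	fmt.Println(resolveEndpoint(""))                       // falls back to the default
	fmt.Println(resolveEndpoint("https://proxy.local/v1")) // explicit endpoint wins
}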

conversation/echo/echo.go

Lines changed: 76 additions & 37 deletions
@@ -18,6 +18,8 @@ import (
 	"context"
 	"fmt"
 	"reflect"
+	"strconv"
+	"strings"

 	"github.com/tmc/langchaingo/llms"

@@ -61,67 +63,104 @@ func (e *Echo) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {

 // Converse returns one output per input message.
 func (e *Echo) Converse(ctx context.Context, r *conversation.Request) (res *conversation.Response, err error) {
-	if r.Message == nil {
+	if r == nil || r.Message == nil {
 		return &conversation.Response{
 			ConversationContext: r.ConversationContext,
 			Outputs:             []conversation.Result{},
 		}, nil
 	}

-	outputs := make([]conversation.Result, 0, len(*r.Message))
+	// if we get tools, respond with tool calls for each tool
+	var toolCalls []llms.ToolCall
+	if r.Tools != nil {
+		// create tool calls for each tool
+		toolCalls = make([]llms.ToolCall, 0, len(*r.Tools))
+		for id, tool := range *r.Tools {
+			// extract argument names from parameters.properties
+			if tool.Function == nil {
+				continue // skip if no function
+			}
+			// try to get parameters/arg-names from tool function if any
+			var parameters map[string]any
+			var argNames []string
+			if tool.Function.Parameters != nil {
+				// ensure parameters are a map
+				ok := false
+				parameters, ok = tool.Function.Parameters.(map[string]any)
+				if !ok {
+					return nil, fmt.Errorf("tool function parameters must be a map[string]any, got %T", tool.Function.Parameters)
+				}
+			}
+			// try get arg names from properties
+			if properties, ok := parameters["properties"]; ok {
+				_, ok = properties.(map[string]any)
+				if !ok {
+					return nil, fmt.Errorf("tool function properties must be a map[string]any, got %T", properties)
+				}
+				if propMap, ok := properties.(map[string]any); ok && len(propMap) != 0 {
+					argNames = make([]string, 0, len(propMap))
+					for argName := range propMap {
+						argNames = append(argNames, argName)
+					}
+				}
+			}

-	for _, message := range *r.Message {
-		var content string
-		var toolCalls []llms.ToolCall
+			toolCalls = append(toolCalls, llms.ToolCall{
+				ID:   strconv.Itoa(id),
+				Type: tool.Type,
+				FunctionCall: &llms.FunctionCall{
+					Name:      tool.Function.Name,
+					Arguments: strings.Join(argNames, ","),
+				},
+			})
+		}
+	}

-		for i, part := range message.Parts {
+	// iterate over each message in the request to echo back the content in the response. We respond with the acummulated content of the message parts and tool responses
+	contentFromMessaged := make([]string, 0, len(*r.Message))
+	for _, message := range *r.Message {
+		for _, part := range message.Parts {
 			switch p := part.(type) {
 			case llms.TextContent:
-				// end with space if not the first part
-				if i > 0 && content != "" {
-					content += " "
-				}
-				content += p.Text
+				// append to slice that we'll join later with new line separators
+				contentFromMessaged = append(contentFromMessaged, p.Text)
 			case llms.ToolCall:
+				// in case we added explicit tool calls on the request like on multi-turn conversations. We still append tool calls for each tool defined in the request.
 				toolCalls = append(toolCalls, p)
 			case llms.ToolCallResponse:
-				content = p.Content
-				toolCalls = append(toolCalls, llms.ToolCall{
-					ID:   p.ToolCallID,
-					Type: "function",
-					FunctionCall: &llms.FunctionCall{
-						Name:      p.Name,
-						Arguments: p.Content,
-					},
-				})
+				// show tool responses on the request like on multi-turn conversations
+				contentFromMessaged = append(contentFromMessaged, fmt.Sprintf("Tool Response for tool ID '%s' with name '%s': %s", p.ToolCallID, p.Name, p.Content))
 			default:
 				return nil, fmt.Errorf("found invalid content type as input for %v", p)
 			}
 		}
+	}

-		choice := conversation.Choice{
-			FinishReason: "stop",
-			Index:        0,
-			Message: conversation.Message{
-				Content: content,
-			},
-		}
-
-		if len(toolCalls) > 0 {
-			choice.Message.ToolCallRequest = &toolCalls
-		}
+	stopReason := "stop"
+	if len(toolCalls) > 0 {
+		stopReason = "tool_calls"
+		// follows openai spec for tool_calls finish reason https://platform.openai.com/docs/api-reference/chat/object
+	}
+	choice := conversation.Choice{
+		FinishReason: stopReason,
+		Index:        0,
+		Message: conversation.Message{
+			Content: strings.Join(contentFromMessaged, "\n"),
+		},
+	}

-		output := conversation.Result{
-			StopReason: "stop",
-			Choices:    []conversation.Choice{choice},
-		}
+	if len(toolCalls) > 0 {
+		choice.Message.ToolCallRequest = &toolCalls
+	}

-		outputs = append(outputs, output)
+	output := conversation.Result{
+		StopReason: stopReason,
+		Choices:    []conversation.Choice{choice},
 	}

 	res = &conversation.Response{
 		ConversationContext: r.ConversationContext,
-		Outputs:             outputs,
+		Outputs:             []conversation.Result{output},
 	}

 	return res, nil
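Taken together, Converse now folds every text part and tool response into a single newline-joined Content string and emits synthetic tool calls for any tools supplied on the request, returning exactly one output. The following is a minimal sketch of driving it, mirroring the updated "tool call request" test case below; the import paths for the echo package and the kit logger are assumptions based on the repository layout, not part of this diff.

package main

import (
	"context"
	"fmt"

	"github.com/dapr/components-contrib/conversation"
	"github.com/dapr/components-contrib/conversation/echo"
	"github.com/dapr/kit/logger"
	"github.com/tmc/langchaingo/llms"
)

func main() {
	e := echo.NewEcho(logger.NewLogger("echo-example"))
	e.Init(context.Background(), conversation.Metadata{})

	messages := []llms.MessageContent{
		{Role: llms.ChatMessageTypeHuman, Parts: []llms.ContentPart{llms.TextContent{Text: "hello echo"}}},
	}
	tools := []llms.Tool{
		{
			Type: "function",
			Function: &llms.FunctionDefinition{
				Name: "myfunc",
				Parameters: map[string]any{
					"type":       "object",
					"properties": map[string]any{"name": map[string]any{"type": "string"}},
				},
			},
		},
	}

	resp, err := e.Converse(context.Background(), &conversation.Request{Message: &messages, Tools: &tools})
	if err != nil {
		panic(err)
	}

	out := resp.Outputs[0]
	choice := out.Choices[0]
	// Per the new behavior: StopReason/FinishReason "tool_calls", Content "hello echo",
	// and one synthetic tool call with ID "0" and Arguments "name".
	fmt.Println(out.StopReason, choice.FinishReason, choice.Message.Content)
	if choice.Message.ToolCallRequest != nil {
		for _, tc := range *choice.Message.ToolCallRequest {
			fmt.Println(tc.ID, tc.FunctionCall.Name, tc.FunctionCall.Arguments)
		}
	}
}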

conversation/echo/echo_test.go

Lines changed: 42 additions & 85 deletions
@@ -97,19 +97,7 @@ func TestConverse(t *testing.T) {
 							FinishReason: "stop",
 							Index:        0,
 							Message: conversation.Message{
-								Content: "first message second message",
-							},
-						},
-					},
-				},
-				{
-					StopReason: "stop",
-					Choices: []conversation.Choice{
-						{
-							FinishReason: "stop",
-							Index:        0,
-							Message: conversation.Message{
-								Content: "third message",
+								Content: "first message\nsecond message\nthird message",
 							},
 						},
 					},
@@ -127,7 +115,7 @@ func TestConverse(t *testing.T) {
 				Message: &tt.inputs,
 			})
 			require.NoError(t, err)
-			assert.Len(t, r.Outputs, len(tt.expected.Outputs))
+			assert.Len(t, r.Outputs, 1)
 			assert.Equal(t, tt.expected.Outputs, r.Outputs)
 		})
 	}
@@ -137,20 +125,32 @@ func TestConverseAlpha2(t *testing.T) {
 	tests := []struct {
 		name     string
 		messages []llms.MessageContent
+		tools    []llms.Tool
 		expected *conversation.Response
 	}{
 		{
 			name: "tool call request",
 			messages: []llms.MessageContent{
 				{
-					Role: llms.ChatMessageTypeAI,
+					Role: llms.ChatMessageTypeHuman,
 					Parts: []llms.ContentPart{
-						llms.ToolCall{
-							ID:   "myid",
-							Type: "function",
-							FunctionCall: &llms.FunctionCall{
-								Name:      "myfunc",
-								Arguments: `{"name": "Dapr"}`,
+						llms.TextContent{Text: "hello echo"},
+					},
+				},
+			},
+			tools: []llms.Tool{
+				{
+					Type: "function",
+					Function: &llms.FunctionDefinition{
+						Name:        "myfunc",
+						Description: "A function that does something",
+						Parameters: map[string]any{
+							"type": "object",
+							"properties": map[string]any{
+								"name": map[string]any{
+									"type":        "string",
+									"description": "The name to process",
+								},
 							},
 						},
 					},
@@ -159,19 +159,20 @@ func TestConverseAlpha2(t *testing.T) {
 			expected: &conversation.Response{
 				Outputs: []conversation.Result{
 					{
-						StopReason: "stop",
+						StopReason: "tool_calls",
 						Choices: []conversation.Choice{
 							{
-								FinishReason: "stop",
+								FinishReason: "tool_calls",
 								Index:        0,
 								Message: conversation.Message{
+									Content: "hello echo",
 									ToolCallRequest: &[]llms.ToolCall{
 										{
-											ID:   "myid",
+											ID:   "0", // ID is auto-generated by the echo component
 											Type: "function",
 											FunctionCall: &llms.FunctionCall{
 												Name:      "myfunc",
-												Arguments: `{"name": "Dapr"}`,
+												Arguments: "name",
 											},
 										},
 									},
@@ -183,8 +184,15 @@ func TestConverseAlpha2(t *testing.T) {
 			},
 		},
 		{
-			name: "tool call response",
+			name: "text message with tool call response",
+			// echo responds with the text message and tool call response appended to the message content
 			messages: []llms.MessageContent{
+				{
+					Role: llms.ChatMessageTypeHuman,
+					Parts: []llms.ContentPart{
+						llms.TextContent{Text: "hello echo"},
+					},
+				},
 				{
 					Role: llms.ChatMessageTypeTool,
 					Parts: []llms.ContentPart{
@@ -205,62 +213,7 @@ func TestConverseAlpha2(t *testing.T) {
 						FinishReason: "stop",
 						Index:        0,
 						Message: conversation.Message{
-							Content: "Dapr",
-							ToolCallRequest: &[]llms.ToolCall{
-								{
-									ID:   "myid",
-									Type: "function",
-									FunctionCall: &llms.FunctionCall{
-										Name:      "myfunc",
-										Arguments: "Dapr",
-									},
-								},
-							},
-						},
-					},
-				},
-			},
-		},
-		{
-			name: "mixed content with text and tool call",
-			messages: []llms.MessageContent{
-				{
-					Role: llms.ChatMessageTypeAI,
-					Parts: []llms.ContentPart{
-						llms.TextContent{Text: "text msg"},
-						llms.ToolCall{
-							ID:   "myid",
-							Type: "function",
-							FunctionCall: &llms.FunctionCall{
-								Name:      "myfunc",
-								Arguments: `{"name": "Dapr"}`,
-							},
-						},
-					},
-				},
-			},
-			expected: &conversation.Response{
-				Outputs: []conversation.Result{
-					{
-						StopReason: "stop",
-						Choices: []conversation.Choice{
-							{
-								FinishReason: "stop",
-								Index:        0,
-								Message: conversation.Message{
-									Content: "text msg",
-									ToolCallRequest: &[]llms.ToolCall{
-										{
-											ID:   "myid",
-											Type: "function",
-											FunctionCall: &llms.FunctionCall{
-												Name:      "myfunc",
-												Arguments: `{"name": "Dapr"}`,
-											},
-										},
-									},
+							Content: "hello echo\nTool Response for tool ID 'myid' with name 'myfunc': Dapr",
 						},
 					},
 				},
@@ -275,9 +228,14 @@ func TestConverseAlpha2(t *testing.T) {
 			e := NewEcho(logger.NewLogger("echo test"))
 			e.Init(t.Context(), conversation.Metadata{})

-			r, err := e.Converse(t.Context(), &conversation.Request{
+			request := &conversation.Request{
 				Message: &tt.messages,
-			})
+			}
+			if len(tt.tools) > 0 {
+				request.Tools = &tt.tools
+			}
+
+			r, err := e.Converse(t.Context(), request)
 			require.NoError(t, err)

 			assert.Len(t, r.Outputs, 1)
@@ -295,7 +253,6 @@ func TestConverseAlpha2(t *testing.T) {

 			for j, toolCall := range *choice.Message.ToolCallRequest {
 				expectedToolCall := (*expectedChoice.Message.ToolCallRequest)[j]
-				assert.Equal(t, expectedToolCall.ID, toolCall.ID)
 				assert.Equal(t, expectedToolCall.Type, toolCall.Type)

 				if expectedToolCall.FunctionCall != nil {
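For reference, the expected Content in the reworked "text message with tool call response" case is simply the new fmt.Sprintf line from echo.go joined to the text part with a newline. A tiny standalone sketch of that formatting, with the values copied from that test case:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Mirrors how the echo component now builds Content: text parts and tool
	// responses are collected in order and joined with "\n".
	parts := []string{
		"hello echo",
		fmt.Sprintf("Tool Response for tool ID '%s' with name '%s': %s", "myid", "myfunc", "Dapr"),
	}
	fmt.Println(strings.Join(parts, "\n"))
	// Output:
	// hello echo
	// Tool Response for tool ID 'myid' with name 'myfunc': Dapr
}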
