Skip to content

Commit 4d8c64a

Browse files
authored
Merge pull request #11 from OriginTrail/fix/multitool-call
fix: multi-tool calls do not wait for all calls to finish
2 parents 0901bac + 793191d commit 4d8c64a

File tree

1 file changed

+107
-35
lines changed

1 file changed

+107
-35
lines changed

apps/agent/src/app/(protected)/chat.tsx

Lines changed: 107 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ export default function ChatPage() {
5050
const [messages, setMessages] = useState<ChatMessage[]>([]);
5151
const [isGenerating, setIsGenerating] = useState(false);
5252

53+
const pendingToolCalls = useRef<Set<string>>(new Set()); // Track tool calls that need responses before calling LLM
54+
const toolKAContents = useRef<Map<string, any[]>>(new Map()); // Track KAs across tool calls in a single request
55+
5356
const chatMessagesRef = useRef<ScrollView>(null);
5457

5558
async function callTool(tc: ToolCall & { id: string }) {
@@ -67,7 +70,7 @@ export default function ChatPage() {
6770
output: result.content,
6871
});
6972

70-
return sendMessage({
73+
addToolResultAndCheckCompletion({
7174
role: "tool",
7275
tool_call_id: tc.id,
7376
content: result.content as ToolCallResultContent,
@@ -80,7 +83,7 @@ export default function ChatPage() {
8083
error: err.message,
8184
});
8285

83-
return sendMessage({
86+
addToolResultAndCheckCompletion({
8487
role: "tool",
8588
tool_call_id: tc.id,
8689
content: "Error occurred while calling tool: " + err.message,
@@ -89,52 +92,119 @@ export default function ChatPage() {
8992
});
9093
}
9194

95+
function addToolResultAndCheckCompletion(toolResult: ChatMessage) {
96+
const kaContents: any[] = [];
97+
const otherContents: any[] = [];
98+
99+
for (const c of toContents(toolResult.content) as ToolCallResultContent) {
100+
const kas = parseSourceKAContent(c);
101+
if (kas) kaContents.push(c);
102+
else otherContents.push(c);
103+
}
104+
toolResult.content = otherContents;
105+
106+
const toolCallId = (toolResult as any).tool_call_id;
107+
if (kaContents.length > 0) {
108+
toolKAContents.current.set(toolCallId, kaContents);
109+
}
110+
111+
setMessages((prevMessages) => [...prevMessages, toolResult]);
112+
pendingToolCalls.current.delete(toolCallId);
113+
114+
if (pendingToolCalls.current.size === 0) {
115+
requestCompletion(); // If all tool calls are complete, only then hit the LLM
116+
}
117+
}
118+
119+
async function requestCompletion() {
120+
if (!mcp.token) throw new Error("Unauthorized");
121+
122+
setIsGenerating(true);
123+
try {
124+
let currentMessages: ChatMessage[] = [];
125+
await new Promise<void>((resolve) => {
126+
setMessages((prevMessages) => {
127+
currentMessages = prevMessages;
128+
resolve();
129+
return prevMessages;
130+
});
131+
});
132+
133+
const completion = await makeCompletionRequest(
134+
{
135+
messages: currentMessages,
136+
tools: tools.enabled,
137+
},
138+
{
139+
fetch: (url, opts) => fetch(url.toString(), opts as any) as any,
140+
bearerToken: mcp.token,
141+
},
142+
);
143+
144+
const allKAContents: any[] = [];
145+
toolKAContents.current.forEach((kaContents) => {
146+
allKAContents.push(...kaContents);
147+
});
148+
149+
if (allKAContents.length > 0) {
150+
completion.content = toContents(completion.content);
151+
completion.content.push(...allKAContents);
152+
}
153+
154+
toolKAContents.current.clear();
155+
156+
setMessages((prevMessages) => [...prevMessages, completion]);
157+
158+
if (completion.tool_calls && completion.tool_calls.length > 0) {
159+
completion.tool_calls.forEach((tc: any) => {
160+
pendingToolCalls.current.add(tc.id);
161+
});
162+
}
163+
} finally {
164+
setIsGenerating(false);
165+
setTimeout(() => chatMessagesRef.current?.scrollToEnd(), 100);
166+
}
167+
}
168+
92169
async function cancelToolCall(tc: ToolCall & { id: string }) {
93170
tools.saveCallInfo(tc.id, { input: tc.args, status: "cancelled" });
94171

95-
return sendMessage({
172+
addToolResultAndCheckCompletion({
96173
role: "tool",
97174
tool_call_id: tc.id,
98175
content: "Tool call was cancelled by user",
99176
});
100177
}
101178

102179
async function sendMessage(newMessage: ChatMessage) {
103-
const kaContents: any[] = [];
104-
if (newMessage.role === "tool") {
105-
const otherContents: any[] = [];
106-
for (const c of toContents(newMessage.content) as ToolCallResultContent) {
107-
const kas = parseSourceKAContent(c);
108-
if (kas) kaContents.push(c);
109-
else otherContents.push(c);
110-
}
111-
newMessage.content = otherContents;
112-
}
113-
114180
setMessages((prevMessages) => [...prevMessages, newMessage]);
115181

116182
if (!mcp.token) throw new Error("Unauthorized");
117183

118184
setIsGenerating(true);
119-
const completion = await makeCompletionRequest(
120-
{
121-
messages: [...messages, newMessage],
122-
tools: tools.enabled,
123-
},
124-
{
125-
fetch: (url, opts) => fetch(url.toString(), opts as any) as any,
126-
bearerToken: mcp.token,
127-
},
128-
);
129-
130-
if (newMessage.role === "tool") {
131-
completion.content = toContents(completion.content);
132-
completion.content.push(...kaContents);
185+
try {
186+
const completion = await makeCompletionRequest(
187+
{
188+
messages: [...messages, newMessage],
189+
tools: tools.enabled,
190+
},
191+
{
192+
fetch: (url, opts) => fetch(url.toString(), opts as any) as any,
193+
bearerToken: mcp.token,
194+
},
195+
);
196+
197+
setMessages((prevMessages) => [...prevMessages, completion]);
198+
199+
if (completion.tool_calls && completion.tool_calls.length > 0) {
200+
completion.tool_calls.forEach((tc: any) => {
201+
pendingToolCalls.current.add(tc.id);
202+
});
203+
}
204+
} finally {
205+
setIsGenerating(false);
206+
setTimeout(() => chatMessagesRef.current?.scrollToEnd(), 100);
133207
}
134-
135-
setMessages((prevMessages) => [...prevMessages, completion]);
136-
setIsGenerating(false);
137-
setTimeout(() => chatMessagesRef.current?.scrollToEnd(), 100);
138208
}
139209

140210
const kaResolver = useCallback<SourceKAResolver>(
@@ -151,8 +221,8 @@ export default function ChatPage() {
151221
parsedContent.metadata
152222
.at(0)
153223
?.[
154-
"https://ontology.origintrail.io/dkg/1.0#publishTime"
155-
]?.at(0)?.["@value"] ?? Date.now(),
224+
"https://ontology.origintrail.io/dkg/1.0#publishTime"
225+
]?.at(0)?.["@value"] ?? Date.now(),
156226
).getTime(),
157227
txHash: parsedContent.metadata
158228
.at(0)
@@ -184,7 +254,7 @@ export default function ChatPage() {
184254
parsedContent.metadata
185255
.at(0)
186256
?.["https://ontology.origintrail.io/dkg/1.0#publishTx"]?.at(0)?.[
187-
"@value"
257+
"@value"
188258
] ?? "unknown";
189259
resolved.publisher =
190260
parsedContent.metadata
@@ -360,6 +430,8 @@ export default function ChatPage() {
360430
onStartAgain={() => {
361431
setMessages([]);
362432
tools.reset();
433+
pendingToolCalls.current.clear();
434+
toolKAContents.current.clear();
363435
}}
364436
/>
365437
)}

0 commit comments

Comments
 (0)