Skip to content

Commit 0a96a7f

Browse files
authored
fix(responses): fix subtle bugs in non-function tool calling (llamastack#3817)
We were generating "FunctionToolCall" items even for MCP (and file-search, etc.) server-side calls, which caused ID mismatches and other inconsistencies.
1 parent d709eeb commit 0a96a7f

12 files changed: 10,660 additions and 51 deletions

llama_stack/providers/inline/agents/meta_reference/responses/streaming.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,11 @@
4444
OpenAIResponseObjectStreamResponseRefusalDone,
4545
OpenAIResponseOutput,
4646
OpenAIResponseOutputMessageContentOutputText,
47+
OpenAIResponseOutputMessageFileSearchToolCall,
4748
OpenAIResponseOutputMessageFunctionToolCall,
49+
OpenAIResponseOutputMessageMCPCall,
4850
OpenAIResponseOutputMessageMCPListTools,
51+
OpenAIResponseOutputMessageWebSearchToolCall,
4952
OpenAIResponseText,
5053
OpenAIResponseUsage,
5154
OpenAIResponseUsageInputTokensDetails,
@@ -177,6 +180,7 @@ async def create_response(self) -> AsyncIterator[OpenAIResponseObjectStream]:
177180
# (some providers don't support non-empty response_format when tools are present)
178181
response_format = None if self.ctx.response_format.type == "text" else self.ctx.response_format
179182
logger.debug(f"calling openai_chat_completion with tools: {self.ctx.chat_tools}")
183+
180184
params = OpenAIChatCompletionRequestWithExtraBody(
181185
model=self.ctx.model,
182186
messages=messages,
@@ -613,19 +617,22 @@ async def _process_streaming_chunks(
613617

614618
# Emit output_item.added event for the new function call
615619
self.sequence_number += 1
616-
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
617-
arguments="", # Will be filled incrementally via delta events
618-
call_id=tool_call.id or "",
619-
name=tool_call.function.name if tool_call.function else "",
620-
id=tool_call_item_id,
621-
status="in_progress",
622-
)
623-
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
624-
response_id=self.response_id,
625-
item=function_call_item,
626-
output_index=len(output_messages),
627-
sequence_number=self.sequence_number,
628-
)
620+
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
621+
if not is_mcp_tool and tool_call.function.name not in ["web_search", "knowledge_search"]:
622+
# for MCP tools (and even other non-function tools) we emit an output message item later
623+
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
624+
arguments="", # Will be filled incrementally via delta events
625+
call_id=tool_call.id or "",
626+
name=tool_call.function.name if tool_call.function else "",
627+
id=tool_call_item_id,
628+
status="in_progress",
629+
)
630+
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
631+
response_id=self.response_id,
632+
item=function_call_item,
633+
output_index=len(output_messages),
634+
sequence_number=self.sequence_number,
635+
)
629636

630637
# Stream tool call arguments as they arrive (differentiate between MCP and function calls)
631638
if tool_call.function and tool_call.function.arguments:
@@ -806,6 +813,35 @@ async def _coordinate_tool_execution(
806813
if not matching_item_id:
807814
matching_item_id = f"tc_{uuid.uuid4()}"
808815

816+
self.sequence_number += 1
817+
if tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server:
818+
item = OpenAIResponseOutputMessageMCPCall(
819+
arguments="",
820+
name=tool_call.function.name,
821+
id=matching_item_id,
822+
server_label=self.mcp_tool_to_server[tool_call.function.name].server_label,
823+
status="in_progress",
824+
)
825+
elif tool_call.function.name == "web_search":
826+
item = OpenAIResponseOutputMessageWebSearchToolCall(
827+
id=matching_item_id,
828+
status="in_progress",
829+
)
830+
elif tool_call.function.name == "knowledge_search":
831+
item = OpenAIResponseOutputMessageFileSearchToolCall(
832+
id=matching_item_id,
833+
status="in_progress",
834+
)
835+
else:
836+
raise ValueError(f"Unsupported tool call: {tool_call.function.name}")
837+
838+
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
839+
response_id=self.response_id,
840+
item=item,
841+
output_index=len(output_messages),
842+
sequence_number=self.sequence_number,
843+
)
844+
809845
# Execute tool call with streaming
810846
tool_call_log = None
811847
tool_response_message = None
@@ -1064,7 +1100,11 @@ async def _add_mcp_list_tools(
10641100
self.sequence_number += 1
10651101
yield OpenAIResponseObjectStreamResponseOutputItemAdded(
10661102
response_id=self.response_id,
1067-
item=mcp_list_message,
1103+
item=OpenAIResponseOutputMessageMCPListTools(
1104+
id=mcp_list_message.id,
1105+
server_label=mcp_list_message.server_label,
1106+
tools=[],
1107+
),
10681108
output_index=len(output_messages) - 1,
10691109
sequence_number=self.sequence_number,
10701110
)

llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ async def execute_tool_call(
9393

9494
# Build result messages from tool execution
9595
output_message, input_message = await self._build_result_messages(
96-
function, tool_call_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
96+
function, tool_call_id, item_id, tool_kwargs, ctx, error_exc, result, has_error, mcp_tool_to_server
9797
)
9898

9999
# Yield the final result
@@ -356,6 +356,7 @@ async def _build_result_messages(
356356
self,
357357
function,
358358
tool_call_id: str,
359+
item_id: str,
359360
tool_kwargs: dict,
360361
ctx: ChatCompletionContext,
361362
error_exc: Exception | None,
@@ -375,7 +376,7 @@ async def _build_result_messages(
375376
)
376377

377378
message = OpenAIResponseOutputMessageMCPCall(
378-
id=tool_call_id,
379+
id=item_id,
379380
arguments=function.arguments,
380381
name=function.name,
381382
server_label=mcp_tool_to_server[function.name].server_label,
@@ -389,14 +390,14 @@ async def _build_result_messages(
389390
else:
390391
if function.name == "web_search":
391392
message = OpenAIResponseOutputMessageWebSearchToolCall(
392-
id=tool_call_id,
393+
id=item_id,
393394
status="completed",
394395
)
395396
if has_error:
396397
message.status = "failed"
397398
elif function.name == "knowledge_search":
398399
message = OpenAIResponseOutputMessageFileSearchToolCall(
399-
id=tool_call_id,
400+
id=item_id,
400401
queries=[tool_kwargs.get("query", "")],
401402
status="completed",
402403
)

0 commit comments

Comments (0)