diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 80ef068c79..55bf31f575 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -11,6 +11,7 @@ import warnings from collections.abc import AsyncGenerator from datetime import UTC, datetime +from typing import Any import httpx @@ -125,12 +126,12 @@ def __init__( ) def turn_to_messages(self, turn: Turn) -> list[Message]: - messages = [] + messages: list[Message] = [] # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages tool_call_ids = set() for step in turn.steps: - if step.step_type == StepType.tool_execution.value: + if step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep): for response in step.tool_responses: tool_call_ids.add(response.call_id) @@ -149,9 +150,9 @@ def turn_to_messages(self, turn: Turn) -> list[Message]: messages.append(msg) for step in turn.steps: - if step.step_type == StepType.inference.value: + if step.step_type == StepType.inference.value and isinstance(step, InferenceStep): messages.append(step.model_response) - elif step.step_type == StepType.tool_execution.value: + elif step.step_type == StepType.tool_execution.value and isinstance(step, ToolExecutionStep): for response in step.tool_responses: messages.append( ToolResponseMessage( @@ -159,8 +160,8 @@ def turn_to_messages(self, turn: Turn) -> list[Message]: content=response.content, ) ) - elif step.step_type == StepType.shield_call.value: - if step.violation: + elif step.step_type == StepType.shield_call.value and isinstance(step, ShieldCallStep): + if step.violation and step.violation.user_message: # CompletionMessage itself in the ShieldResponse messages.append( CompletionMessage( @@ -174,7 +175,7 @@ async def create_session(self, name: str) -> str: return await self.storage.create_session(name) async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]: - messages = [] + messages: list[Message] = [] if self.agent_config.instructions != "": messages.append(SystemMessage(content=self.agent_config.instructions)) @@ -231,7 +232,9 @@ async def _run_turn( steps = [] messages = await self.get_messages_from_turns(turns) + if is_resume: + assert isinstance(request, AgentTurnResumeRequest) tool_response_messages = [ ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses ] @@ -252,42 +255,52 @@ async def _run_turn( in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( request.session_id, request.turn_id ) - now = datetime.now(UTC).isoformat() + now_dt = datetime.now(UTC) tool_execution_step = ToolExecutionStep( step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), turn_id=request.turn_id, tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), tool_responses=request.tool_responses, - completed_at=now, - started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), + completed_at=now_dt, + started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now_dt), ) steps.append(tool_execution_step) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.tool_execution.value, + step_type=StepType.tool_execution, step_id=tool_execution_step.step_id, step_details=tool_execution_step, ) ) ) - input_messages = last_turn.input_messages + # Cast needed due to list invariance - last_turn.input_messages is the right type + input_messages = last_turn.input_messages # type: ignore[assignment] - turn_id = request.turn_id + actual_turn_id = request.turn_id start_time = last_turn.started_at else: + assert isinstance(request, AgentTurnCreateRequest) messages.extend(request.messages) - start_time = datetime.now(UTC).isoformat() - input_messages = request.messages + start_time = datetime.now(UTC) + # Cast needed due to list invariance - request.messages is the right type + input_messages = request.messages # type: ignore[assignment] + # Use the generated turn_id from beginning of function + actual_turn_id = turn_id if turn_id else str(uuid.uuid4()) output_message = None + req_documents = request.documents if isinstance(request, AgentTurnCreateRequest) and not is_resume else None + req_sampling = ( + self.agent_config.sampling_params if self.agent_config.sampling_params is not None else SamplingParams() + ) + async for chunk in self.run( session_id=request.session_id, - turn_id=turn_id, + turn_id=actual_turn_id, input_messages=messages, - sampling_params=self.agent_config.sampling_params, + sampling_params=req_sampling, stream=request.stream, - documents=request.documents if not is_resume else None, + documents=req_documents, ): if isinstance(chunk, CompletionMessage): output_message = chunk @@ -295,20 +308,23 @@ async def _run_turn( assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" event = chunk.event - if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: - steps.append(event.payload.step_details) + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value and hasattr( + event.payload, "step_details" + ): + step_details = event.payload.step_details + steps.append(step_details) yield chunk assert output_message is not None turn = Turn( - turn_id=turn_id, + turn_id=actual_turn_id, session_id=request.session_id, - input_messages=input_messages, + input_messages=input_messages, # type: ignore[arg-type] output_message=output_message, started_at=start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), steps=steps, ) await self.storage.add_turn_to_session(request.session_id, turn) @@ -345,7 +361,7 @@ async def run( # return a "final value" for the `yield from` statement. we simulate that by yielding a # final boolean (to see whether an exception happened) and then explicitly testing for it. - if len(self.input_shields) > 0: + if self.input_shields: async for res in self.run_multiple_shields_wrapper( turn_id, input_messages, self.input_shields, "user-input" ): @@ -374,7 +390,7 @@ async def run( # for output shields run on the full input and output combination messages = input_messages + [final_response] - if len(self.output_shields) > 0: + if self.output_shields: async for res in self.run_multiple_shields_wrapper( turn_id, messages, self.output_shields, "assistant-output" ): @@ -402,12 +418,12 @@ async def run_multiple_shields_wrapper( return step_id = str(uuid.uuid4()) - shield_call_start_time = datetime.now(UTC).isoformat() + shield_call_start_time = datetime.now(UTC) try: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepStartPayload( - step_type=StepType.shield_call.value, + step_type=StepType.shield_call, step_id=step_id, metadata=dict(touchpoint=touchpoint), ) @@ -419,14 +435,14 @@ async def run_multiple_shields_wrapper( yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.shield_call.value, + step_type=StepType.shield_call, step_id=step_id, step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, violation=e.violation, started_at=shield_call_start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), ), ) ) @@ -443,14 +459,14 @@ async def run_multiple_shields_wrapper( yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.shield_call.value, + step_type=StepType.shield_call, step_id=step_id, step_details=ShieldCallStep( step_id=step_id, turn_id=turn_id, violation=None, started_at=shield_call_start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), ), ) ) @@ -496,21 +512,22 @@ async def _run( else: self.tool_name_to_args[tool_name]["vector_store_ids"].append(session_info.vector_store_id) - output_attachments = [] + output_attachments: list[Attachment] = [] n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 # Build a map of custom tools to their definitions for faster lookup client_tools = {} - for tool in self.agent_config.client_tools: - client_tools[tool.name] = tool + if self.agent_config.client_tools: + for tool in self.agent_config.client_tools: + client_tools[tool.name] = tool while True: step_id = str(uuid.uuid4()) - inference_start_time = datetime.now(UTC).isoformat() + inference_start_time = datetime.now(UTC) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepStartPayload( - step_type=StepType.inference.value, + step_type=StepType.inference, step_id=step_id, ) ) @@ -538,7 +555,7 @@ def _serialize_nested(value): else: return value - def _add_type(openai_msg: dict) -> OpenAIMessageParam: + def _add_type(openai_msg: Any) -> OpenAIMessageParam: # Serialize any nested Pydantic models to plain dicts openai_msg = _serialize_nested(openai_msg) @@ -588,7 +605,7 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: messages=openai_messages, tools=openai_tools if openai_tools else None, tool_choice=tool_choice, - response_format=self.agent_config.response_format, + response_format=self.agent_config.response_format, # type: ignore[arg-type] temperature=temperature, top_p=top_p, max_tokens=max_tokens, @@ -598,7 +615,8 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: # Convert OpenAI stream back to Llama Stack format response_stream = convert_openai_chat_completion_stream( - openai_stream, enable_incremental_tool_calls=True + openai_stream, # type: ignore[arg-type] + enable_incremental_tool_calls=True, ) async for chunk in response_stream: @@ -620,7 +638,7 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.inference.value, + step_type=StepType.inference, step_id=step_id, delta=delta, ) @@ -633,7 +651,7 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.inference.value, + step_type=StepType.inference, step_id=step_id, delta=delta, ) @@ -651,7 +669,9 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: output_attr = json.dumps( { "content": content, - "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls], + "tool_calls": [ + json.loads(t.model_dump_json()) for t in tool_calls if isinstance(t, ToolCall) + ], } ) span.set_attribute("output", output_attr) @@ -667,16 +687,18 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: if tool_calls: content = "" + # Filter out string tool calls for CompletionMessage (only keep ToolCall objects) + valid_tool_calls = [t for t in tool_calls if isinstance(t, ToolCall)] message = CompletionMessage( content=content, stop_reason=stop_reason, - tool_calls=tool_calls, + tool_calls=valid_tool_calls if valid_tool_calls else None, ) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.inference.value, + step_type=StepType.inference, step_id=step_id, step_details=InferenceStep( # somewhere deep, we are re-assigning message or closing over some @@ -686,13 +708,14 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: turn_id=turn_id, model_response=copy.deepcopy(message), started_at=inference_start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), ), ) ) ) - if n_iter >= self.agent_config.max_infer_iters: + max_iters = self.agent_config.max_infer_iters if self.agent_config.max_infer_iters is not None else 10 + if n_iter >= max_iters: logger.info(f"done with MAX iterations ({n_iter}), exiting.") # NOTE: mark end_of_turn to indicate to client that we are done with the turn # Do not continue the tool call loop after this point @@ -705,14 +728,16 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: yield message break - if len(message.tool_calls) == 0: + if not message.tool_calls or len(message.tool_calls) == 0: if stop_reason == StopReason.end_of_turn: # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS) if len(output_attachments) > 0: if isinstance(message.content, list): - message.content += output_attachments + # List invariance - attachments are compatible at runtime + message.content += output_attachments # type: ignore[arg-type] else: - message.content = [message.content] + output_attachments + # List invariance - attachments are compatible at runtime + message.content = [message.content] + output_attachments # type: ignore[assignment] yield message else: logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") @@ -725,11 +750,12 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: non_client_tool_calls = [] # Separate client and non-client tool calls - for tool_call in message.tool_calls: - if tool_call.tool_name in client_tools: - client_tool_calls.append(tool_call) - else: - non_client_tool_calls.append(tool_call) + if message.tool_calls: + for tool_call in message.tool_calls: + if tool_call.tool_name in client_tools: + client_tool_calls.append(tool_call) + else: + non_client_tool_calls.append(tool_call) # Process non-client tool calls first for tool_call in non_client_tool_calls: @@ -737,7 +763,7 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepStartPayload( - step_type=StepType.tool_execution.value, + step_type=StepType.tool_execution, step_id=step_id, ) ) @@ -746,7 +772,7 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.tool_execution.value, + step_type=StepType.tool_execution, step_id=step_id, delta=ToolCallDelta( parse_status=ToolCallParseStatus.in_progress, @@ -766,7 +792,7 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: if self.telemetry_enabled else {}, ) as span: - tool_execution_start_time = datetime.now(UTC).isoformat() + tool_execution_start_time = datetime.now(UTC) tool_result = await self.execute_tool_call_maybe( session_id, tool_call, @@ -796,14 +822,14 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: ) ], started_at=tool_execution_start_time, - completed_at=datetime.now(UTC).isoformat(), + completed_at=datetime.now(UTC), ) # Yield the step completion event yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.tool_execution.value, + step_type=StepType.tool_execution, step_id=step_id, step_details=tool_execution_step, ) @@ -833,7 +859,7 @@ def _add_type(openai_msg: dict) -> OpenAIMessageParam: turn_id=turn_id, tool_calls=client_tool_calls, tool_responses=[], - started_at=datetime.now(UTC).isoformat(), + started_at=datetime.now(UTC), ), ) @@ -868,19 +894,20 @@ async def _initialize_tools( toolgroup_to_args = toolgroup_to_args or {} - tool_name_to_def = {} + tool_name_to_def: dict[str, ToolDefinition] = {} tool_name_to_args = {} - for tool_def in self.agent_config.client_tools: - if tool_name_to_def.get(tool_def.name, None): - raise ValueError(f"Tool {tool_def.name} already exists") + if self.agent_config.client_tools: + for tool_def in self.agent_config.client_tools: + if tool_name_to_def.get(tool_def.name, None): + raise ValueError(f"Tool {tool_def.name} already exists") - # Use input_schema from ToolDef directly - tool_name_to_def[tool_def.name] = ToolDefinition( - tool_name=tool_def.name, - description=tool_def.description, - input_schema=tool_def.input_schema, - ) + # Use input_schema from ToolDef directly + tool_name_to_def[tool_def.name] = ToolDefinition( + tool_name=tool_def.name, + description=tool_def.description, + input_schema=tool_def.input_schema, + ) for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) @@ -908,15 +935,17 @@ async def _initialize_tools( else: identifier = None - if tool_name_to_def.get(identifier, None): - raise ValueError(f"Tool {identifier} already exists") if identifier: - tool_name_to_def[identifier] = ToolDefinition( - tool_name=identifier, + # Convert BuiltinTool to string for dictionary key + identifier_str = identifier.value if isinstance(identifier, BuiltinTool) else identifier + if tool_name_to_def.get(identifier_str, None): + raise ValueError(f"Tool {identifier_str} already exists") + tool_name_to_def[identifier_str] = ToolDefinition( + tool_name=identifier_str, description=tool_def.description, input_schema=tool_def.input_schema, ) - tool_name_to_args[identifier] = toolgroup_to_args.get(toolgroup_name, {}) + tool_name_to_args[identifier_str] = toolgroup_to_args.get(toolgroup_name, {}) self.tool_defs, self.tool_name_to_args = ( list(tool_name_to_def.values()), @@ -1017,7 +1046,7 @@ def _interpret_content_as_attachment( snippet = match.group(1) data = json.loads(snippet) return Attachment( - url=URL(uri="file://" + data["filepath"]), + content=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"], ) diff --git a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py index 8e0dc9ecbd..09a161d50b 100644 --- a/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py +++ b/src/llama_stack/providers/inline/agents/meta_reference/responses/tool_executor.py @@ -7,6 +7,7 @@ import asyncio import json from collections.abc import AsyncIterator +from typing import Any from llama_stack.apis.agents.openai_responses import ( OpenAIResponseInputToolFileSearch, @@ -22,6 +23,7 @@ OpenAIResponseObjectStreamResponseWebSearchCallSearching, OpenAIResponseOutputMessageFileSearchToolCall, OpenAIResponseOutputMessageFileSearchToolCallResults, + OpenAIResponseOutputMessageMCPCall, OpenAIResponseOutputMessageWebSearchToolCall, ) from llama_stack.apis.common.content_types import ( @@ -67,7 +69,7 @@ async def execute_tool_call( ) -> AsyncIterator[ToolExecutionResult]: tool_call_id = tool_call.id function = tool_call.function - tool_kwargs = json.loads(function.arguments) if function.arguments else {} + tool_kwargs = json.loads(function.arguments) if function and function.arguments else {} if not function or not tool_call_id or not function.name: yield ToolExecutionResult(sequence_number=sequence_number) @@ -84,7 +86,16 @@ async def execute_tool_call( error_exc, result = await self._execute_tool(function.name, tool_kwargs, ctx, mcp_tool_to_server) # Emit completion events for tool execution - has_error = error_exc or (result and ((result.error_code and result.error_code > 0) or result.error_message)) + has_error = bool( + error_exc + or ( + result + and ( + ((error_code := getattr(result, "error_code", None)) and error_code > 0) + or getattr(result, "error_message", None) + ) + ) + ) async for event_result in self._emit_completion_events( function.name, ctx, sequence_number, output_index, item_id, has_error, mcp_tool_to_server ): @@ -101,7 +112,9 @@ async def execute_tool_call( sequence_number=sequence_number, final_output_message=output_message, final_input_message=input_message, - citation_files=result.metadata.get("citation_files") if result and result.metadata else None, + citation_files=( + metadata.get("citation_files") if result and (metadata := getattr(result, "metadata", None)) else None + ), ) async def _execute_knowledge_search_via_vector_store( @@ -188,8 +201,9 @@ async def search_single_store(vector_store_id): citation_files[file_id] = filename + # Cast to proper InterleavedContent type (list invariance) return ToolInvocationResult( - content=content_items, + content=content_items, # type: ignore[arg-type] metadata={ "document_ids": [r.file_id for r in search_results], "chunks": [r.content[0].text if r.content else "" for r in search_results], @@ -209,51 +223,60 @@ async def _emit_progress_events( ) -> AsyncIterator[ToolExecutionResult]: """Emit progress events for tool execution start.""" # Emit in_progress event based on tool type (only for tools with specific streaming events) - progress_event = None if mcp_tool_to_server and function_name in mcp_tool_to_server: sequence_number += 1 - progress_event = OpenAIResponseObjectStreamResponseMcpCallInProgress( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseMcpCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) elif function_name == "web_search": sequence_number += 1 - progress_event = OpenAIResponseObjectStreamResponseWebSearchCallInProgress( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseWebSearchCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) elif function_name == "knowledge_search": sequence_number += 1 - progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseFileSearchCallInProgress( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) - if progress_event: - yield ToolExecutionResult(stream_event=progress_event, sequence_number=sequence_number) - # For web search, emit searching event if function_name == "web_search": sequence_number += 1 - searching_event = OpenAIResponseObjectStreamResponseWebSearchCallSearching( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseWebSearchCallSearching( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) - yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) # For file search, emit searching event if function_name == "knowledge_search": sequence_number += 1 - searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching( - item_id=item_id, - output_index=output_index, + yield ToolExecutionResult( + stream_event=OpenAIResponseObjectStreamResponseFileSearchCallSearching( + item_id=item_id, + output_index=output_index, + sequence_number=sequence_number, + ), sequence_number=sequence_number, ) - yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number) async def _execute_tool( self, @@ -261,7 +284,7 @@ async def _execute_tool( tool_kwargs: dict, ctx: ChatCompletionContext, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, - ) -> tuple[Exception | None, any]: + ) -> tuple[Exception | None, Any]: """Execute the tool and return error exception and result.""" error_exc = None result = None @@ -284,9 +307,13 @@ async def _execute_tool( kwargs=tool_kwargs, ) elif function_name == "knowledge_search": - response_file_search_tool = next( - (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), - None, + response_file_search_tool = ( + next( + (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)), + None, + ) + if ctx.response_tools + else None ) if response_file_search_tool: # Use vector_stores.search API instead of knowledge_search tool @@ -322,35 +349,34 @@ async def _emit_completion_events( mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, ) -> AsyncIterator[ToolExecutionResult]: """Emit completion or failure events for tool execution.""" - completion_event = None - if mcp_tool_to_server and function_name in mcp_tool_to_server: sequence_number += 1 if has_error: - completion_event = OpenAIResponseObjectStreamResponseMcpCallFailed( + mcp_failed_event = OpenAIResponseObjectStreamResponseMcpCallFailed( sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=mcp_failed_event, sequence_number=sequence_number) else: - completion_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( + mcp_completed_event = OpenAIResponseObjectStreamResponseMcpCallCompleted( sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=mcp_completed_event, sequence_number=sequence_number) elif function_name == "web_search": sequence_number += 1 - completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( + web_completion_event = OpenAIResponseObjectStreamResponseWebSearchCallCompleted( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) + yield ToolExecutionResult(stream_event=web_completion_event, sequence_number=sequence_number) elif function_name == "knowledge_search": sequence_number += 1 - completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( + file_completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted( item_id=item_id, output_index=output_index, sequence_number=sequence_number, ) - - if completion_event: - yield ToolExecutionResult(stream_event=completion_event, sequence_number=sequence_number) + yield ToolExecutionResult(stream_event=file_completion_event, sequence_number=sequence_number) async def _build_result_messages( self, @@ -360,21 +386,18 @@ async def _build_result_messages( tool_kwargs: dict, ctx: ChatCompletionContext, error_exc: Exception | None, - result: any, + result: Any, has_error: bool, mcp_tool_to_server: dict[str, OpenAIResponseInputToolMCP] | None = None, - ) -> tuple[any, any]: + ) -> tuple[Any, Any]: """Build output and input messages from tool execution results.""" from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) # Build output message + message: Any if mcp_tool_to_server and function.name in mcp_tool_to_server: - from llama_stack.apis.agents.openai_responses import ( - OpenAIResponseOutputMessageMCPCall, - ) - message = OpenAIResponseOutputMessageMCPCall( id=item_id, arguments=function.arguments, @@ -383,10 +406,14 @@ async def _build_result_messages( ) if error_exc: message.error = str(error_exc) - elif (result and result.error_code and result.error_code > 0) or (result and result.error_message): - message.error = f"Error (code {result.error_code}): {result.error_message}" - elif result and result.content: - message.output = interleaved_content_as_str(result.content) + elif (result and (error_code := getattr(result, "error_code", None)) and error_code > 0) or ( + result and getattr(result, "error_message", None) + ): + ec = getattr(result, "error_code", "unknown") + em = getattr(result, "error_message", "") + message.error = f"Error (code {ec}): {em}" + elif result and (content := getattr(result, "content", None)): + message.output = interleaved_content_as_str(content) else: if function.name == "web_search": message = OpenAIResponseOutputMessageWebSearchToolCall( @@ -401,17 +428,17 @@ async def _build_result_messages( queries=[tool_kwargs.get("query", "")], status="completed", ) - if result and "document_ids" in result.metadata: + if result and (metadata := getattr(result, "metadata", None)) and "document_ids" in metadata: message.results = [] - for i, doc_id in enumerate(result.metadata["document_ids"]): - text = result.metadata["chunks"][i] if "chunks" in result.metadata else None - score = result.metadata["scores"][i] if "scores" in result.metadata else None + for i, doc_id in enumerate(metadata["document_ids"]): + text = metadata["chunks"][i] if "chunks" in metadata else None + score = metadata["scores"][i] if "scores" in metadata else None message.results.append( OpenAIResponseOutputMessageFileSearchToolCallResults( file_id=doc_id, filename=doc_id, - text=text, - score=score, + text=text if text is not None else "", + score=score if score is not None else 0.0, attributes={}, ) ) @@ -421,27 +448,32 @@ async def _build_result_messages( raise ValueError(f"Unknown tool {function.name} called") # Build input message - input_message = None - if result and result.content: - if isinstance(result.content, str): - content = result.content - elif isinstance(result.content, list): - content = [] - for item in result.content: + input_message: OpenAIToolMessageParam | None = None + if result and (result_content := getattr(result, "content", None)): + # all the mypy contortions here are still unsatisfactory with random Any typing + if isinstance(result_content, str): + msg_content: str | list[Any] = result_content + elif isinstance(result_content, list): + content_list: list[Any] = [] + for item in result_content: + part: Any if isinstance(item, TextContentItem): part = OpenAIChatCompletionContentPartTextParam(text=item.text) elif isinstance(item, ImageContentItem): if item.image.data: - url = f"data:image;base64,{item.image.data}" + url_value = f"data:image;base64,{item.image.data}" else: - url = item.image.url - part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url)) + url_value = str(item.image.url) if item.image.url else "" + part = OpenAIChatCompletionContentPartImageParam(image_url=OpenAIImageURL(url=url_value)) else: raise ValueError(f"Unknown result content type: {type(item)}") - content.append(part) + content_list.append(part) + msg_content = content_list else: - raise ValueError(f"Unknown result content type: {type(result.content)}") - input_message = OpenAIToolMessageParam(content=content, tool_call_id=tool_call_id) + raise ValueError(f"Unknown result content type: {type(result_content)}") + # OpenAIToolMessageParam accepts str | list[TextParam] but we may have images + # This is runtime-safe as the API accepts it, but mypy complains + input_message = OpenAIToolMessageParam(content=msg_content, tool_call_id=tool_call_id) # type: ignore[arg-type] else: text = str(error_exc) if error_exc else "Tool execution failed" input_message = OpenAIToolMessageParam(content=text, tool_call_id=tool_call_id)