diff --git a/examples/basic/stream_function_call_args.py b/examples/basic/stream_function_call_args.py index 46e72896c..3c3538772 100644 --- a/examples/basic/stream_function_call_args.py +++ b/examples/basic/stream_function_call_args.py @@ -35,7 +35,7 @@ async def main(): result = Runner.run_streamed( agent, - input="Create a Python web project called 'my-app' with FastAPI. Version 1.0.0, dependencies: fastapi, uvicorn" + input="Create a Python web project called 'my-app' with FastAPI. Version 1.0.0, dependencies: fastapi, uvicorn", ) # Track function calls for detailed output @@ -50,10 +50,7 @@ async def main(): function_name = getattr(event.data.item, "name", "unknown") call_id = getattr(event.data.item, "call_id", "unknown") - function_calls[call_id] = { - 'name': function_name, - 'arguments': "" - } + function_calls[call_id] = {"name": function_name, "arguments": ""} current_active_call_id = call_id print(f"\n📞 Function call streaming started: {function_name}()") print("📝 Arguments building...") @@ -61,12 +58,12 @@ async def main(): # Real-time argument streaming elif isinstance(event.data, ResponseFunctionCallArgumentsDeltaEvent): if current_active_call_id and current_active_call_id in function_calls: - function_calls[current_active_call_id]['arguments'] += event.data.delta + function_calls[current_active_call_id]["arguments"] += event.data.delta print(event.data.delta, end="", flush=True) # Function call completed elif event.data.type == "response.output_item.done": - if hasattr(event.data.item, 'call_id'): + if hasattr(event.data.item, "call_id"): call_id = getattr(event.data.item, "call_id", "unknown") if call_id in function_calls: function_info = function_calls[call_id] diff --git a/examples/customer_service/main.py b/examples/customer_service/main.py index 8ed218536..266a7e611 100644 --- a/examples/customer_service/main.py +++ b/examples/customer_service/main.py @@ -40,7 +40,10 @@ class AirlineAgentContext(BaseModel): ) async def faq_lookup_tool(question: str) -> str: question_lower = question.lower() - if any(keyword in question_lower for keyword in ["bag", "baggage", "luggage", "carry-on", "hand luggage", "hand carry"]): + if any( + keyword in question_lower + for keyword in ["bag", "baggage", "luggage", "carry-on", "hand luggage", "hand carry"] + ): return ( "You are allowed to bring one bag on the plane. " "It must be under 50 pounds and 22 inches x 14 inches x 9 inches." @@ -52,7 +55,10 @@ async def faq_lookup_tool(question: str) -> str: "Exit rows are rows 4 and 16. " "Rows 5-8 are Economy Plus, with extra legroom. " ) - elif any(keyword in question_lower for keyword in ["wifi", "internet", "wireless", "connectivity", "network", "online"]): + elif any( + keyword in question_lower + for keyword in ["wifi", "internet", "wireless", "connectivity", "network", "online"] + ): return "We have free wifi on the plane, join Airline-Wifi" return "I'm sorry, I don't know the answer to that question." diff --git a/examples/realtime/app/server.py b/examples/realtime/app/server.py index 73fcf3e56..04f3def43 100644 --- a/examples/realtime/app/server.py +++ b/examples/realtime/app/server.py @@ -4,11 +4,12 @@ import logging import struct from contextlib import asynccontextmanager -from typing import TYPE_CHECKING, Any, assert_never +from typing import TYPE_CHECKING, Any from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.responses import FileResponse from fastapi.staticfiles import StaticFiles +from typing_extensions import assert_never from agents.realtime import RealtimeRunner, RealtimeSession, RealtimeSessionEvent diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index 4d70f6058..2c52737ad 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -119,9 +119,9 @@ class Handoff(Generic[TContext, TAgent]): True, as it increases the likelihood of correct JSON input. """ - is_enabled: bool | Callable[ - [RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool] - ] = True + is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = ( + True + ) """Whether the handoff is enabled. Either a bool or a Callable that takes the run context and agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable a handoff based on your context/state.""" diff --git a/src/agents/model_settings.py b/src/agents/model_settings.py index 71e66ed84..f76d64266 100644 --- a/src/agents/model_settings.py +++ b/src/agents/model_settings.py @@ -55,7 +55,6 @@ class MCPToolChoice: ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None] - @dataclass class ModelSettings: """Settings to use when calling an LLM. diff --git a/src/agents/run.py b/src/agents/run.py index d0748e514..397dff223 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -4,9 +4,12 @@ import copy import inspect from dataclasses import dataclass, field -from typing import Any, Generic, cast +from typing import Any, Generic, cast, get_args -from openai.types.responses import ResponseCompletedEvent +from openai.types.responses import ( + ResponseCompletedEvent, + ResponseOutputItemAddedEvent, +) from openai.types.responses.response_prompt_param import ( ResponsePromptParam, ) @@ -41,7 +44,14 @@ OutputGuardrailResult, ) from .handoffs import Handoff, HandoffInputFilter, handoff -from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem +from .items import ( + ItemHelpers, + ModelResponse, + RunItem, + ToolCallItem, + ToolCallItemTypes, + TResponseInputItem, +) from .lifecycle import RunHooks from .logger import logger from .memory import Session @@ -50,7 +60,7 @@ from .models.multi_provider import MultiProvider from .result import RunResult, RunResultStreaming from .run_context import RunContextWrapper, TContext -from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent +from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent from .tool import Tool from .tracing import Span, SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData @@ -833,6 +843,8 @@ async def _run_single_turn_streamed( all_tools: list[Tool], previous_response_id: str | None, ) -> SingleStepResult: + emitted_tool_call_ids: set[str] = set() + if should_run_agent_start_hooks: await asyncio.gather( hooks.on_agent_start(context_wrapper, agent), @@ -897,16 +909,35 @@ async def _run_single_turn_streamed( ) context_wrapper.usage.add(usage) + if isinstance(event, ResponseOutputItemAddedEvent): + output_item = event.item + + if isinstance(output_item, _TOOL_CALL_TYPES): + call_id: str | None = getattr( + output_item, "call_id", getattr(output_item, "id", None) + ) + + if call_id and call_id not in emitted_tool_call_ids: + emitted_tool_call_ids.add(call_id) + + tool_item = ToolCallItem( + raw_item=cast(ToolCallItemTypes, output_item), + agent=agent, + ) + streamed_result._event_queue.put_nowait( + RunItemStreamEvent(item=tool_item, name="tool_called") + ) + streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event)) - # 2. At this point, the streaming is complete for this turn of the agent loop. if not final_response: raise ModelBehaviorError("Model did not produce a final response!") # 3. Now, we can process the turn as we do in the non-streaming case - return await cls._get_single_step_result_from_streamed_response( + single_step_result = await cls._get_single_step_result_from_response( agent=agent, - streamed_result=streamed_result, + original_input=streamed_result.input, + pre_step_items=streamed_result.new_items, new_response=final_response, output_schema=output_schema, all_tools=all_tools, @@ -917,6 +948,34 @@ async def _run_single_turn_streamed( tool_use_tracker=tool_use_tracker, ) + if emitted_tool_call_ids: + import dataclasses as _dc + + filtered_items = [ + item + for item in single_step_result.new_step_items + if not ( + isinstance(item, ToolCallItem) + and ( + call_id := getattr( + item.raw_item, "call_id", getattr(item.raw_item, "id", None) + ) + ) + and call_id in emitted_tool_call_ids + ) + ] + + single_step_result_filtered = _dc.replace( + single_step_result, new_step_items=filtered_items + ) + + RunImpl.stream_step_result_to_queue( + single_step_result_filtered, streamed_result._event_queue + ) + else: + RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue) + return single_step_result + @classmethod async def _run_single_turn( cls, @@ -1019,57 +1078,6 @@ async def _get_single_step_result_from_response( run_config=run_config, ) - @classmethod - async def _get_single_step_result_from_streamed_response( - cls, - *, - agent: Agent[TContext], - all_tools: list[Tool], - streamed_result: RunResultStreaming, - new_response: ModelResponse, - output_schema: AgentOutputSchemaBase | None, - handoffs: list[Handoff], - hooks: RunHooks[TContext], - context_wrapper: RunContextWrapper[TContext], - run_config: RunConfig, - tool_use_tracker: AgentToolUseTracker, - ) -> SingleStepResult: - - original_input = streamed_result.input - pre_step_items = streamed_result.new_items - event_queue = streamed_result._event_queue - - processed_response = RunImpl.process_model_response( - agent=agent, - all_tools=all_tools, - response=new_response, - output_schema=output_schema, - handoffs=handoffs, - ) - new_items_processed_response = processed_response.new_items - tool_use_tracker.add_tool_use(agent, processed_response.tools_used) - RunImpl.stream_step_items_to_queue(new_items_processed_response, event_queue) - - single_step_result = await RunImpl.execute_tools_and_side_effects( - agent=agent, - original_input=original_input, - pre_step_items=pre_step_items, - new_response=new_response, - processed_response=processed_response, - output_schema=output_schema, - hooks=hooks, - context_wrapper=context_wrapper, - run_config=run_config, - ) - new_step_items = [ - item - for item in single_step_result.new_step_items - if item not in new_items_processed_response - ] - RunImpl.stream_step_items_to_queue(new_step_items, event_queue) - - return single_step_result - @classmethod async def _run_input_guardrails( cls, @@ -1287,3 +1295,5 @@ async def _save_result_to_session( DEFAULT_AGENT_RUNNER = AgentRunner() + +_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes) diff --git a/src/agents/tracing/processors.py b/src/agents/tracing/processors.py index 32fd290ec..126c71498 100644 --- a/src/agents/tracing/processors.py +++ b/src/agents/tracing/processors.py @@ -70,8 +70,8 @@ def set_api_key(self, api_key: str): client. """ # Clear the cached property if it exists - if 'api_key' in self.__dict__: - del self.__dict__['api_key'] + if "api_key" in self.__dict__: + del self.__dict__["api_key"] # Update the private attribute self._api_key = api_key diff --git a/tests/test_agent_clone_shallow_copy.py b/tests/test_agent_clone_shallow_copy.py index fdf9e0247..44b41bd3d 100644 --- a/tests/test_agent_clone_shallow_copy.py +++ b/tests/test_agent_clone_shallow_copy.py @@ -5,6 +5,7 @@ def greet(name: str) -> str: return f"Hello, {name}!" + def test_agent_clone_shallow_copy(): """Test that clone creates shallow copy with tools.copy() workaround""" target_agent = Agent(name="Target") @@ -16,9 +17,7 @@ def test_agent_clone_shallow_copy(): ) cloned = original.clone( - name="Cloned", - tools=original.tools.copy(), - handoffs=original.handoffs.copy() + name="Cloned", tools=original.tools.copy(), handoffs=original.handoffs.copy() ) # Basic assertions diff --git a/tests/test_stream_events.py b/tests/test_stream_events.py index 11feb9fe0..0f85b63f8 100644 --- a/tests/test_stream_events.py +++ b/tests/test_stream_events.py @@ -14,6 +14,7 @@ async def foo() -> str: await asyncio.sleep(3) return "success!" + @pytest.mark.asyncio async def test_stream_events_main(): model = FakeModel()