Skip to content

Commit e8d311b

Browse files
habema and seratch authored
Fix: Emit tool_called events immediately in streaming runs (#1300)
Co-authored-by: Kazuhiro Sera <[email protected]>
1 parent 7dda9d8 commit e8d311b

File tree

8 files changed

+88
-24
lines changed

8 files changed

+88
-24
lines changed

examples/basic/stream_function_call_args.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ async def main():
3535

3636
result = Runner.run_streamed(
3737
agent,
38-
input="Create a Python web project called 'my-app' with FastAPI. Version 1.0.0, dependencies: fastapi, uvicorn"
38+
input="Create a Python web project called 'my-app' with FastAPI. Version 1.0.0, dependencies: fastapi, uvicorn",
3939
)
4040

4141
# Track function calls for detailed output
@@ -50,23 +50,20 @@ async def main():
5050
function_name = getattr(event.data.item, "name", "unknown")
5151
call_id = getattr(event.data.item, "call_id", "unknown")
5252

53-
function_calls[call_id] = {
54-
'name': function_name,
55-
'arguments': ""
56-
}
53+
function_calls[call_id] = {"name": function_name, "arguments": ""}
5754
current_active_call_id = call_id
5855
print(f"\n📞 Function call streaming started: {function_name}()")
5956
print("📝 Arguments building...")
6057

6158
# Real-time argument streaming
6259
elif isinstance(event.data, ResponseFunctionCallArgumentsDeltaEvent):
6360
if current_active_call_id and current_active_call_id in function_calls:
64-
function_calls[current_active_call_id]['arguments'] += event.data.delta
61+
function_calls[current_active_call_id]["arguments"] += event.data.delta
6562
print(event.data.delta, end="", flush=True)
6663

6764
# Function call completed
6865
elif event.data.type == "response.output_item.done":
69-
if hasattr(event.data.item, 'call_id'):
66+
if hasattr(event.data.item, "call_id"):
7067
call_id = getattr(event.data.item, "call_id", "unknown")
7168
if call_id in function_calls:
7269
function_info = function_calls[call_id]

examples/customer_service/main.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ class AirlineAgentContext(BaseModel):
4040
)
4141
async def faq_lookup_tool(question: str) -> str:
4242
question_lower = question.lower()
43-
if any(keyword in question_lower for keyword in ["bag", "baggage", "luggage", "carry-on", "hand luggage", "hand carry"]):
43+
if any(
44+
keyword in question_lower
45+
for keyword in ["bag", "baggage", "luggage", "carry-on", "hand luggage", "hand carry"]
46+
):
4447
return (
4548
"You are allowed to bring one bag on the plane. "
4649
"It must be under 50 pounds and 22 inches x 14 inches x 9 inches."
@@ -52,7 +55,10 @@ async def faq_lookup_tool(question: str) -> str:
5255
"Exit rows are rows 4 and 16. "
5356
"Rows 5-8 are Economy Plus, with extra legroom. "
5457
)
55-
elif any(keyword in question_lower for keyword in ["wifi", "internet", "wireless", "connectivity", "network", "online"]):
58+
elif any(
59+
keyword in question_lower
60+
for keyword in ["wifi", "internet", "wireless", "connectivity", "network", "online"]
61+
):
5662
return "We have free wifi on the plane, join Airline-Wifi"
5763
return "I'm sorry, I don't know the answer to that question."
5864

src/agents/handoffs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,9 @@ class Handoff(Generic[TContext, TAgent]):
119119
True, as it increases the likelihood of correct JSON input.
120120
"""
121121

122-
is_enabled: bool | Callable[
123-
[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]
124-
] = True
122+
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = (
123+
True
124+
)
125125
"""Whether the handoff is enabled. Either a bool or a Callable that takes the run context and
126126
agent and returns whether the handoff is enabled. You can use this to dynamically enable/disable
127127
a handoff based on your context/state."""

src/agents/model_settings.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ class MCPToolChoice:
5555
ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]
5656

5757

58-
5958
@dataclass
6059
class ModelSettings:
6160
"""Settings to use when calling an LLM.

src/agents/run.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,12 @@
33
import asyncio
44
import inspect
55
from dataclasses import dataclass, field
6-
from typing import Any, Callable, Generic, cast
6+
from typing import Any, Callable, Generic, cast, get_args
77

8-
from openai.types.responses import ResponseCompletedEvent
8+
from openai.types.responses import (
9+
ResponseCompletedEvent,
10+
ResponseOutputItemAddedEvent,
11+
)
912
from openai.types.responses.response_prompt_param import (
1013
ResponsePromptParam,
1114
)
@@ -40,7 +43,14 @@
4043
OutputGuardrailResult,
4144
)
4245
from .handoffs import Handoff, HandoffInputFilter, handoff
43-
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
46+
from .items import (
47+
ItemHelpers,
48+
ModelResponse,
49+
RunItem,
50+
ToolCallItem,
51+
ToolCallItemTypes,
52+
TResponseInputItem,
53+
)
4454
from .lifecycle import RunHooks
4555
from .logger import logger
4656
from .memory import Session
@@ -49,7 +59,7 @@
4959
from .models.multi_provider import MultiProvider
5060
from .result import RunResult, RunResultStreaming
5161
from .run_context import RunContextWrapper, TContext
52-
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
62+
from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent, RunItemStreamEvent
5363
from .tool import Tool
5464
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
5565
from .tracing.span_data import AgentSpanData
@@ -905,6 +915,8 @@ async def _run_single_turn_streamed(
905915
all_tools: list[Tool],
906916
previous_response_id: str | None,
907917
) -> SingleStepResult:
918+
emitted_tool_call_ids: set[str] = set()
919+
908920
if should_run_agent_start_hooks:
909921
await asyncio.gather(
910922
hooks.on_agent_start(context_wrapper, agent),
@@ -984,6 +996,25 @@ async def _run_single_turn_streamed(
984996
)
985997
context_wrapper.usage.add(usage)
986998

999+
if isinstance(event, ResponseOutputItemAddedEvent):
1000+
output_item = event.item
1001+
1002+
if isinstance(output_item, _TOOL_CALL_TYPES):
1003+
call_id: str | None = getattr(
1004+
output_item, "call_id", getattr(output_item, "id", None)
1005+
)
1006+
1007+
if call_id and call_id not in emitted_tool_call_ids:
1008+
emitted_tool_call_ids.add(call_id)
1009+
1010+
tool_item = ToolCallItem(
1011+
raw_item=cast(ToolCallItemTypes, output_item),
1012+
agent=agent,
1013+
)
1014+
streamed_result._event_queue.put_nowait(
1015+
RunItemStreamEvent(item=tool_item, name="tool_called")
1016+
)
1017+
9871018
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
9881019

9891020
# Call hook just after the model response is finalized.
@@ -995,9 +1026,10 @@ async def _run_single_turn_streamed(
9951026
raise ModelBehaviorError("Model did not produce a final response!")
9961027

9971028
# 3. Now, we can process the turn as we do in the non-streaming case
998-
return await cls._get_single_step_result_from_streamed_response(
1029+
single_step_result = await cls._get_single_step_result_from_response(
9991030
agent=agent,
1000-
streamed_result=streamed_result,
1031+
original_input=streamed_result.input,
1032+
pre_step_items=streamed_result.new_items,
10011033
new_response=final_response,
10021034
output_schema=output_schema,
10031035
all_tools=all_tools,
@@ -1008,6 +1040,34 @@ async def _run_single_turn_streamed(
10081040
tool_use_tracker=tool_use_tracker,
10091041
)
10101042

1043+
if emitted_tool_call_ids:
1044+
import dataclasses as _dc
1045+
1046+
filtered_items = [
1047+
item
1048+
for item in single_step_result.new_step_items
1049+
if not (
1050+
isinstance(item, ToolCallItem)
1051+
and (
1052+
call_id := getattr(
1053+
item.raw_item, "call_id", getattr(item.raw_item, "id", None)
1054+
)
1055+
)
1056+
and call_id in emitted_tool_call_ids
1057+
)
1058+
]
1059+
1060+
single_step_result_filtered = _dc.replace(
1061+
single_step_result, new_step_items=filtered_items
1062+
)
1063+
1064+
RunImpl.stream_step_result_to_queue(
1065+
single_step_result_filtered, streamed_result._event_queue
1066+
)
1067+
else:
1068+
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
1069+
return single_step_result
1070+
10111071
@classmethod
10121072
async def _run_single_turn(
10131073
cls,
@@ -1397,9 +1457,11 @@ async def _save_result_to_session(
13971457

13981458

13991459
DEFAULT_AGENT_RUNNER = AgentRunner()
1460+
_TOOL_CALL_TYPES: tuple[type, ...] = get_args(ToolCallItemTypes)
14001461

14011462

14021463
def _copy_str_or_list(input: str | list[TResponseInputItem]) -> str | list[TResponseInputItem]:
14031464
if isinstance(input, str):
14041465
return input
14051466
return input.copy()
1467+

src/agents/tracing/processors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def set_api_key(self, api_key: str):
7070
client.
7171
"""
7272
# Clear the cached property if it exists
73-
if 'api_key' in self.__dict__:
74-
del self.__dict__['api_key']
73+
if "api_key" in self.__dict__:
74+
del self.__dict__["api_key"]
7575

7676
# Update the private attribute
7777
self._api_key = api_key

tests/test_agent_clone_shallow_copy.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
def greet(name: str) -> str:
66
return f"Hello, {name}!"
77

8+
89
def test_agent_clone_shallow_copy():
910
"""Test that clone creates shallow copy with tools.copy() workaround"""
1011
target_agent = Agent(name="Target")
@@ -16,9 +17,7 @@ def test_agent_clone_shallow_copy():
1617
)
1718

1819
cloned = original.clone(
19-
name="Cloned",
20-
tools=original.tools.copy(),
21-
handoffs=original.handoffs.copy()
20+
name="Cloned", tools=original.tools.copy(), handoffs=original.handoffs.copy()
2221
)
2322

2423
# Basic assertions

tests/test_stream_events.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ async def foo() -> str:
1414
await asyncio.sleep(3)
1515
return "success!"
1616

17+
1718
@pytest.mark.asyncio
1819
async def test_stream_events_main():
1920
model = FakeModel()

0 commit comments

Comments
 (0)