Skip to content

Commit 5daa8e1

Browse files
committed
Feat: separate tool_call_item and tool_call_output_item in stream events
1 parent 18cb55e commit 5daa8e1

File tree

3 files changed

+130
-9
lines changed

3 files changed

+130
-9
lines changed

src/agents/_run_impl.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -907,12 +907,12 @@ async def run_single_output_guardrail(
907907
return result
908908

909909
@classmethod
910-
def stream_step_result_to_queue(
910+
def stream_step_items_to_queue(
911911
cls,
912-
step_result: SingleStepResult,
912+
new_step_items: list[RunItem],
913913
queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel],
914914
):
915-
for item in step_result.new_step_items:
915+
for item in new_step_items:
916916
if isinstance(item, MessageOutputItem):
917917
event = RunItemStreamEvent(item=item, name="message_output_created")
918918
elif isinstance(item, HandoffCallItem):
@@ -937,6 +937,14 @@ def stream_step_result_to_queue(
937937
if event:
938938
queue.put_nowait(event)
939939

940+
@classmethod
941+
def stream_step_result_to_queue(
942+
cls,
943+
step_result: SingleStepResult,
944+
queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel],
945+
):
946+
cls.stream_step_items_to_queue(step_result.new_step_items, queue)
947+
940948
@classmethod
941949
async def _check_for_final_output_from_tools(
942950
cls,

src/agents/run.py

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -854,10 +854,9 @@ async def _run_single_turn_streamed(
854854
raise ModelBehaviorError("Model did not produce a final response!")
855855

856856
# 3. Now, we can process the turn as we do in the non-streaming case
857-
single_step_result = await cls._get_single_step_result_from_response(
857+
return await cls._get_single_step_result_from_streamed_response(
858858
agent=agent,
859-
original_input=streamed_result.input,
860-
pre_step_items=streamed_result.new_items,
859+
streamed_result=streamed_result,
861860
new_response=final_response,
862861
output_schema=output_schema,
863862
all_tools=all_tools,
@@ -868,9 +867,6 @@ async def _run_single_turn_streamed(
868867
tool_use_tracker=tool_use_tracker,
869868
)
870869

871-
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
872-
return single_step_result
873-
874870
@classmethod
875871
async def _run_single_turn(
876872
cls,
@@ -973,6 +969,57 @@ async def _get_single_step_result_from_response(
973969
run_config=run_config,
974970
)
975971

972+
@classmethod
973+
async def _get_single_step_result_from_streamed_response(
974+
cls,
975+
*,
976+
agent: Agent[TContext],
977+
all_tools: list[Tool],
978+
streamed_result: RunResultStreaming,
979+
new_response: ModelResponse,
980+
output_schema: AgentOutputSchemaBase | None,
981+
handoffs: list[Handoff],
982+
hooks: RunHooks[TContext],
983+
context_wrapper: RunContextWrapper[TContext],
984+
run_config: RunConfig,
985+
tool_use_tracker: AgentToolUseTracker,
986+
) -> SingleStepResult:
987+
988+
original_input = streamed_result.input
989+
pre_step_items = streamed_result.new_items
990+
event_queue = streamed_result._event_queue
991+
992+
processed_response = RunImpl.process_model_response(
993+
agent=agent,
994+
all_tools=all_tools,
995+
response=new_response,
996+
output_schema=output_schema,
997+
handoffs=handoffs,
998+
)
999+
new_items_processed_response = processed_response.new_items
1000+
tool_use_tracker.add_tool_use(agent, processed_response.tools_used)
1001+
RunImpl.stream_step_items_to_queue(new_items_processed_response, event_queue)
1002+
1003+
single_step_result = await RunImpl.execute_tools_and_side_effects(
1004+
agent=agent,
1005+
original_input=original_input,
1006+
pre_step_items=pre_step_items,
1007+
new_response=new_response,
1008+
processed_response=processed_response,
1009+
output_schema=output_schema,
1010+
hooks=hooks,
1011+
context_wrapper=context_wrapper,
1012+
run_config=run_config,
1013+
)
1014+
new_step_items = [
1015+
item
1016+
for item in single_step_result.new_step_items
1017+
if item not in new_items_processed_response
1018+
]
1019+
RunImpl.stream_step_items_to_queue(new_step_items, event_queue)
1020+
1021+
return single_step_result
1022+
9761023
@classmethod
9771024
async def _run_input_guardrails(
9781025
cls,

tests/test_stream_events.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import asyncio
2+
import time
3+
4+
import pytest
5+
6+
from agents import Agent, ItemHelpers, Runner, function_tool
7+
8+
from .fake_model import FakeModel
9+
from .test_responses import get_function_tool_call, get_text_message
10+
11+
12+
@function_tool
13+
async def foo() -> str:
14+
await asyncio.sleep(3)
15+
return "success!"
16+
17+
@pytest.mark.asyncio
18+
async def test_stream_events_main():
19+
model = FakeModel()
20+
agent = Agent(
21+
name="Joker",
22+
model=model,
23+
tools=[foo],
24+
)
25+
26+
model.add_multiple_turn_outputs(
27+
[
28+
# First turn: a message and tool call
29+
[
30+
get_text_message("a_message"),
31+
get_function_tool_call("foo", ""),
32+
],
33+
# Second turn: text message
34+
[get_text_message("done")],
35+
]
36+
)
37+
38+
result = Runner.run_streamed(
39+
agent,
40+
input="Hello",
41+
)
42+
print("=== Run starting ===")
43+
tool_call_start_time = -1
44+
tool_call_end_time = -1
45+
async for event in result.stream_events():
46+
# We'll ignore the raw responses event deltas
47+
if event.type == "raw_response_event":
48+
continue
49+
elif event.type == "agent_updated_stream_event":
50+
print(f"Agent updated: {event.new_agent.name}")
51+
elif event.type == "run_item_stream_event":
52+
if event.item.type == "tool_call_item":
53+
tool_call_start_time = time.time_ns()
54+
print(f"-- Tool was called at {tool_call_start_time}")
55+
elif event.item.type == "tool_call_output_item":
56+
tool_call_end_time = time.time_ns()
57+
print(f"-- Tool output: {event.item.output} at {tool_call_end_time}")
58+
elif event.item.type == "message_output_item":
59+
print(
60+
f"-- Message output:\n {ItemHelpers.text_message_output(event.item)}"
61+
)
62+
63+
print("=== Run complete ===")
64+
assert tool_call_start_time > 0, "tool_call_item was not observed"
65+
assert tool_call_end_time > 0, "tool_call_output_item was not observed"
66+
assert tool_call_start_time < tool_call_end_time, "Tool call ended before or equals it started?"

0 commit comments

Comments
 (0)