Skip to content

Commit 9078e29

Browse files
authored
Fix: Correct streaming order for ReasoningItem and RawResponsesStreamEvent events (openai#1869)
1 parent 0442b82 commit 9078e29

File tree

5 files changed

+392
-9
lines changed

5 files changed

+392
-9
lines changed

src/agents/run.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from openai.types.responses.response_prompt_param import (
1414
ResponsePromptParam,
1515
)
16+
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
1617
from typing_extensions import NotRequired, TypedDict, Unpack
1718

1819
from ._run_impl import (
@@ -48,6 +49,7 @@
4849
HandoffCallItem,
4950
ItemHelpers,
5051
ModelResponse,
52+
ReasoningItem,
5153
RunItem,
5254
ToolCallItem,
5355
ToolCallItemTypes,
@@ -1097,6 +1099,7 @@ async def _run_single_turn_streamed(
10971099
server_conversation_tracker: _ServerConversationTracker | None = None,
10981100
) -> SingleStepResult:
10991101
emitted_tool_call_ids: set[str] = set()
1102+
emitted_reasoning_item_ids: set[str] = set()
11001103

11011104
if should_run_agent_start_hooks:
11021105
await asyncio.gather(
@@ -1178,6 +1181,9 @@ async def _run_single_turn_streamed(
11781181
conversation_id=conversation_id,
11791182
prompt=prompt_config,
11801183
):
1184+
# Emit the raw event ASAP
1185+
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
1186+
11811187
if isinstance(event, ResponseCompletedEvent):
11821188
usage = (
11831189
Usage(
@@ -1217,7 +1223,16 @@ async def _run_single_turn_streamed(
12171223
RunItemStreamEvent(item=tool_item, name="tool_called")
12181224
)
12191225

1220-
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
1226+
elif isinstance(output_item, ResponseReasoningItem):
1227+
reasoning_id: str | None = getattr(output_item, "id", None)
1228+
1229+
if reasoning_id and reasoning_id not in emitted_reasoning_item_ids:
1230+
emitted_reasoning_item_ids.add(reasoning_id)
1231+
1232+
reasoning_item = ReasoningItem(raw_item=output_item, agent=agent)
1233+
streamed_result._event_queue.put_nowait(
1234+
RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created")
1235+
)
12211236

12221237
# Call hook just after the model response is finalized.
12231238
if final_response is not None:
@@ -1271,6 +1286,18 @@ async def _run_single_turn_streamed(
12711286
)
12721287
]
12731288

1289+
if emitted_reasoning_item_ids:
1290+
# Filter out reasoning items that were already emitted during streaming
1291+
items_to_filter = [
1292+
item
1293+
for item in items_to_filter
1294+
if not (
1295+
isinstance(item, ReasoningItem)
1296+
and (reasoning_id := getattr(item.raw_item, "id", None))
1297+
and reasoning_id in emitted_reasoning_item_ids
1298+
)
1299+
]
1300+
12741301
# Filter out HandoffCallItem to avoid duplicates (already sent earlier)
12751302
items_to_filter = [
12761303
item for item in items_to_filter if not isinstance(item, HandoffCallItem)

tests/fake_model.py

Lines changed: 170 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,33 @@
33
from collections.abc import AsyncIterator
44
from typing import Any
55

6-
from openai.types.responses import Response, ResponseCompletedEvent, ResponseUsage
6+
from openai.types.responses import (
7+
Response,
8+
ResponseCompletedEvent,
9+
ResponseContentPartAddedEvent,
10+
ResponseContentPartDoneEvent,
11+
ResponseCreatedEvent,
12+
ResponseFunctionCallArgumentsDeltaEvent,
13+
ResponseFunctionCallArgumentsDoneEvent,
14+
ResponseFunctionToolCall,
15+
ResponseInProgressEvent,
16+
ResponseOutputItemAddedEvent,
17+
ResponseOutputItemDoneEvent,
18+
ResponseOutputMessage,
19+
ResponseOutputText,
20+
ResponseReasoningSummaryPartAddedEvent,
21+
ResponseReasoningSummaryPartDoneEvent,
22+
ResponseReasoningSummaryTextDeltaEvent,
23+
ResponseReasoningSummaryTextDoneEvent,
24+
ResponseTextDeltaEvent,
25+
ResponseTextDoneEvent,
26+
ResponseUsage,
27+
)
28+
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
29+
from openai.types.responses.response_reasoning_summary_part_added_event import (
30+
Part as AddedEventPart,
31+
)
32+
from openai.types.responses.response_reasoning_summary_part_done_event import Part as DoneEventPart
733
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
834

935
from agents.agent_output import AgentOutputSchemaBase
@@ -143,10 +169,151 @@ async def stream_response(
143169
)
144170
raise output
145171

172+
response = get_response_obj(output, usage=self.hardcoded_usage)
173+
sequence_number = 0
174+
175+
yield ResponseCreatedEvent(
176+
type="response.created",
177+
response=response,
178+
sequence_number=sequence_number,
179+
)
180+
sequence_number += 1
181+
182+
yield ResponseInProgressEvent(
183+
type="response.in_progress",
184+
response=response,
185+
sequence_number=sequence_number,
186+
)
187+
sequence_number += 1
188+
189+
for output_index, output_item in enumerate(output):
190+
yield ResponseOutputItemAddedEvent(
191+
type="response.output_item.added",
192+
item=output_item,
193+
output_index=output_index,
194+
sequence_number=sequence_number,
195+
)
196+
sequence_number += 1
197+
198+
if isinstance(output_item, ResponseReasoningItem):
199+
if output_item.summary:
200+
for summary_index, summary in enumerate(output_item.summary):
201+
yield ResponseReasoningSummaryPartAddedEvent(
202+
type="response.reasoning_summary_part.added",
203+
item_id=output_item.id,
204+
output_index=output_index,
205+
summary_index=summary_index,
206+
part=AddedEventPart(text=summary.text, type=summary.type),
207+
sequence_number=sequence_number,
208+
)
209+
sequence_number += 1
210+
211+
yield ResponseReasoningSummaryTextDeltaEvent(
212+
type="response.reasoning_summary_text.delta",
213+
item_id=output_item.id,
214+
output_index=output_index,
215+
summary_index=summary_index,
216+
delta=summary.text,
217+
sequence_number=sequence_number,
218+
)
219+
sequence_number += 1
220+
221+
yield ResponseReasoningSummaryTextDoneEvent(
222+
type="response.reasoning_summary_text.done",
223+
item_id=output_item.id,
224+
output_index=output_index,
225+
summary_index=summary_index,
226+
text=summary.text,
227+
sequence_number=sequence_number,
228+
)
229+
sequence_number += 1
230+
231+
yield ResponseReasoningSummaryPartDoneEvent(
232+
type="response.reasoning_summary_part.done",
233+
item_id=output_item.id,
234+
output_index=output_index,
235+
summary_index=summary_index,
236+
part=DoneEventPart(text=summary.text, type=summary.type),
237+
sequence_number=sequence_number,
238+
)
239+
sequence_number += 1
240+
241+
elif isinstance(output_item, ResponseFunctionToolCall):
242+
yield ResponseFunctionCallArgumentsDeltaEvent(
243+
type="response.function_call_arguments.delta",
244+
item_id=output_item.call_id,
245+
output_index=output_index,
246+
delta=output_item.arguments,
247+
sequence_number=sequence_number,
248+
)
249+
sequence_number += 1
250+
251+
yield ResponseFunctionCallArgumentsDoneEvent(
252+
type="response.function_call_arguments.done",
253+
item_id=output_item.call_id,
254+
output_index=output_index,
255+
arguments=output_item.arguments,
256+
sequence_number=sequence_number,
257+
)
258+
sequence_number += 1
259+
260+
elif isinstance(output_item, ResponseOutputMessage):
261+
for content_index, content_part in enumerate(output_item.content):
262+
if isinstance(content_part, ResponseOutputText):
263+
yield ResponseContentPartAddedEvent(
264+
type="response.content_part.added",
265+
item_id=output_item.id,
266+
output_index=output_index,
267+
content_index=content_index,
268+
part=content_part,
269+
sequence_number=sequence_number,
270+
)
271+
sequence_number += 1
272+
273+
yield ResponseTextDeltaEvent(
274+
type="response.output_text.delta",
275+
item_id=output_item.id,
276+
output_index=output_index,
277+
content_index=content_index,
278+
delta=content_part.text,
279+
logprobs=[],
280+
sequence_number=sequence_number,
281+
)
282+
sequence_number += 1
283+
284+
yield ResponseTextDoneEvent(
285+
type="response.output_text.done",
286+
item_id=output_item.id,
287+
output_index=output_index,
288+
content_index=content_index,
289+
text=content_part.text,
290+
logprobs=[],
291+
sequence_number=sequence_number,
292+
)
293+
sequence_number += 1
294+
295+
yield ResponseContentPartDoneEvent(
296+
type="response.content_part.done",
297+
item_id=output_item.id,
298+
output_index=output_index,
299+
content_index=content_index,
300+
part=content_part,
301+
sequence_number=sequence_number,
302+
)
303+
sequence_number += 1
304+
305+
yield ResponseOutputItemDoneEvent(
306+
type="response.output_item.done",
307+
item=output_item,
308+
output_index=output_index,
309+
sequence_number=sequence_number,
310+
)
311+
sequence_number += 1
312+
146313
yield ResponseCompletedEvent(
147314
type="response.completed",
148-
response=get_response_obj(output, usage=self.hardcoded_usage),
149-
sequence_number=0,
315+
response=response,
316+
sequence_number=sequence_number,
150317
)
151318

152319

tests/fastapi/test_streaming_context.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,17 @@ async def test_streaming_context():
2525
body = (await r.aread()).decode("utf-8")
2626
lines = [line for line in body.splitlines() if line]
2727
assert lines == snapshot(
28-
["agent_updated_stream_event", "raw_response_event", "run_item_stream_event"]
28+
[
29+
"agent_updated_stream_event",
30+
"raw_response_event", # ResponseCreatedEvent
31+
"raw_response_event", # ResponseInProgressEvent
32+
"raw_response_event", # ResponseOutputItemAddedEvent
33+
"raw_response_event", # ResponseContentPartAddedEvent
34+
"raw_response_event", # ResponseTextDeltaEvent
35+
"raw_response_event", # ResponseTextDoneEvent
36+
"raw_response_event", # ResponseContentPartDoneEvent
37+
"raw_response_event", # ResponseOutputItemDoneEvent
38+
"raw_response_event", # ResponseCompletedEvent
39+
"run_item_stream_event", # MessageOutputItem
40+
]
2941
)

tests/test_agent_runner_streamed.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -695,11 +695,16 @@ async def test_streaming_events():
695695
# Now lets check the events
696696

697697
expected_item_type_map = {
698-
"tool_call": 2,
698+
# 3 tool_call_item events:
699+
# 1. get_function_tool_call("foo", ...)
700+
# 2. get_handoff_tool_call(agent_1) because handoffs are implemented via tool calls too
701+
# 3. get_function_tool_call("bar", ...)
702+
"tool_call": 3,
703+
# Only 2 outputs, handoff tool call doesn't have corresponding tool_call_output event
699704
"tool_call_output": 2,
700-
"message": 2,
701-
"handoff": 1,
702-
"handoff_output": 1,
705+
"message": 2, # get_text_message("a_message") + get_final_output_message(...)
706+
"handoff": 1, # get_handoff_tool_call(agent_1)
707+
"handoff_output": 1, # handoff_output_item
703708
}
704709

705710
total_expected_item_count = sum(expected_item_type_map.values())

0 commit comments

Comments
 (0)