Skip to content

Commit 6efb3fa

Browse files
committed
fix: track request count when streaming fails mid-request (fixes #1973)
## Problem

When streaming fails (API errors, connection drops, context window exceeded, etc.), usage tracking was completely lost. The issue occurs because usage is only accumulated when `ResponseCompletedEvent` arrives. If the model provider raises an exception before yielding `ResponseCompletedEvent`, the `async for` loop exits without ever updating usage.

## Solution

Wrap the streaming loop in try-except to ensure at least the request count is tracked when streaming fails. While we cannot estimate token counts without introducing dependencies (like tiktoken) or making assumptions about the model, tracking the request count provides valuable information for:

- Monitoring API call frequency
- Debugging failed requests
- Cost estimation (users know a request was made)

Token counts remain at 0 when streaming fails before `ResponseCompletedEvent`, which accurately reflects that we don't have that information.

## Changes

- src/agents/run.py: Added try-except around streaming loop
  - Tracks request count (`Usage(requests=1)`) on exception if no response received
  - Preserves existing behavior for successful streaming
- tests/test_usage_tracking_on_error.py: Comprehensive test coverage
  - Test request tracking on streaming error
  - Test normal usage tracking still works
  - Test multi-turn scenarios with error

Fixes #1973
1 parent 03dca68 commit 6efb3fa

File tree

2 files changed

+209
-58
lines changed

2 files changed

+209
-58
lines changed

src/agents/run.py

Lines changed: 66 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,72 +1234,80 @@ async def _run_single_turn_streamed(
12341234
)
12351235

12361236
# 1. Stream the output events
1237-
async for event in model.stream_response(
1238-
filtered.instructions,
1239-
filtered.input,
1240-
model_settings,
1241-
all_tools,
1242-
output_schema,
1243-
handoffs,
1244-
get_model_tracing_impl(
1245-
run_config.tracing_disabled, run_config.trace_include_sensitive_data
1246-
),
1247-
previous_response_id=previous_response_id,
1248-
conversation_id=conversation_id,
1249-
prompt=prompt_config,
1250-
):
1251-
# Emit the raw event ASAP
1252-
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
1253-
1254-
if isinstance(event, ResponseCompletedEvent):
1255-
usage = (
1256-
Usage(
1257-
requests=1,
1258-
input_tokens=event.response.usage.input_tokens,
1259-
output_tokens=event.response.usage.output_tokens,
1260-
total_tokens=event.response.usage.total_tokens,
1261-
input_tokens_details=event.response.usage.input_tokens_details,
1262-
output_tokens_details=event.response.usage.output_tokens_details,
1237+
try:
1238+
async for event in model.stream_response(
1239+
filtered.instructions,
1240+
filtered.input,
1241+
model_settings,
1242+
all_tools,
1243+
output_schema,
1244+
handoffs,
1245+
get_model_tracing_impl(
1246+
run_config.tracing_disabled, run_config.trace_include_sensitive_data
1247+
),
1248+
previous_response_id=previous_response_id,
1249+
conversation_id=conversation_id,
1250+
prompt=prompt_config,
1251+
):
1252+
# Emit the raw event ASAP
1253+
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
1254+
1255+
if isinstance(event, ResponseCompletedEvent):
1256+
usage = (
1257+
Usage(
1258+
requests=1,
1259+
input_tokens=event.response.usage.input_tokens,
1260+
output_tokens=event.response.usage.output_tokens,
1261+
total_tokens=event.response.usage.total_tokens,
1262+
input_tokens_details=event.response.usage.input_tokens_details,
1263+
output_tokens_details=event.response.usage.output_tokens_details,
1264+
)
1265+
if event.response.usage
1266+
else Usage()
12631267
)
1264-
if event.response.usage
1265-
else Usage()
1266-
)
1267-
final_response = ModelResponse(
1268-
output=event.response.output,
1269-
usage=usage,
1270-
response_id=event.response.id,
1271-
)
1272-
context_wrapper.usage.add(usage)
1273-
1274-
if isinstance(event, ResponseOutputItemDoneEvent):
1275-
output_item = event.item
1276-
1277-
if isinstance(output_item, _TOOL_CALL_TYPES):
1278-
call_id: str | None = getattr(
1279-
output_item, "call_id", getattr(output_item, "id", None)
1268+
final_response = ModelResponse(
1269+
output=event.response.output,
1270+
usage=usage,
1271+
response_id=event.response.id,
12801272
)
1273+
context_wrapper.usage.add(usage)
12811274

1282-
if call_id and call_id not in emitted_tool_call_ids:
1283-
emitted_tool_call_ids.add(call_id)
1275+
if isinstance(event, ResponseOutputItemDoneEvent):
1276+
output_item = event.item
12841277

1285-
tool_item = ToolCallItem(
1286-
raw_item=cast(ToolCallItemTypes, output_item),
1287-
agent=agent,
1288-
)
1289-
streamed_result._event_queue.put_nowait(
1290-
RunItemStreamEvent(item=tool_item, name="tool_called")
1278+
if isinstance(output_item, _TOOL_CALL_TYPES):
1279+
call_id: str | None = getattr(
1280+
output_item, "call_id", getattr(output_item, "id", None)
12911281
)
12921282

1293-
elif isinstance(output_item, ResponseReasoningItem):
1294-
reasoning_id: str | None = getattr(output_item, "id", None)
1283+
if call_id and call_id not in emitted_tool_call_ids:
1284+
emitted_tool_call_ids.add(call_id)
1285+
1286+
tool_item = ToolCallItem(
1287+
raw_item=cast(ToolCallItemTypes, output_item),
1288+
agent=agent,
1289+
)
1290+
streamed_result._event_queue.put_nowait(
1291+
RunItemStreamEvent(item=tool_item, name="tool_called")
1292+
)
1293+
1294+
elif isinstance(output_item, ResponseReasoningItem):
1295+
reasoning_id: str | None = getattr(output_item, "id", None)
12951296

1296-
if reasoning_id and reasoning_id not in emitted_reasoning_item_ids:
1297-
emitted_reasoning_item_ids.add(reasoning_id)
1297+
if reasoning_id and reasoning_id not in emitted_reasoning_item_ids:
1298+
emitted_reasoning_item_ids.add(reasoning_id)
12981299

1299-
reasoning_item = ReasoningItem(raw_item=output_item, agent=agent)
1300-
streamed_result._event_queue.put_nowait(
1301-
RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created")
1302-
)
1300+
reasoning_item = ReasoningItem(raw_item=output_item, agent=agent)
1301+
streamed_result._event_queue.put_nowait(
1302+
RunItemStreamEvent(item=reasoning_item, name="reasoning_item_created")
1303+
)
1304+
except Exception:
1305+
# Track that a request was made, even if we didn't get usage info from the response
1306+
# This ensures usage tracking is not completely lost when streaming fails
1307+
# (Issue #1973)
1308+
if final_response is None:
1309+
context_wrapper.usage.add(Usage(requests=1))
1310+
raise
13031311

13041312
# Call hook just after the model response is finalized.
13051313
if final_response is not None:
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""Test that usage tracking works correctly when streaming fails.
2+
3+
This addresses Issue #1973: Usage tracking lost when streaming fails mid-request.
4+
"""
5+
6+
import pytest
7+
8+
from agents import Agent, Runner
9+
10+
from .fake_model import FakeModel
11+
12+
13+
@pytest.mark.asyncio
14+
async def test_usage_tracking_requests_on_streaming_error():
15+
"""Test that at least request count is tracked when streaming fails.
16+
17+
This addresses Issue #1973: When the model raises an error during streaming,
18+
we should track that a request was made, even if token counts are unavailable.
19+
"""
20+
model = FakeModel()
21+
22+
# Simulate a streaming failure (e.g., context window exceeded, connection drop)
23+
model.set_next_output(RuntimeError("Context window exceeded"))
24+
25+
agent = Agent(
26+
name="test_agent",
27+
model=model,
28+
)
29+
30+
# Run the agent and expect it to fail
31+
with pytest.raises(RuntimeError):
32+
result = Runner.run_streamed(agent, input="Test input that consumes tokens")
33+
async for _ in result.stream_events():
34+
pass
35+
36+
# FIXED: Request count should be tracked even when streaming fails
37+
assert result.context_wrapper.usage.requests == 1, "Request count should be tracked on error"
38+
39+
# Token counts are unavailable when streaming fails before ResponseCompletedEvent
40+
assert result.context_wrapper.usage.input_tokens == 0
41+
assert result.context_wrapper.usage.output_tokens == 0
42+
assert result.context_wrapper.usage.total_tokens == 0
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_usage_tracking_preserved_on_success():
47+
"""Test that normal usage tracking still works correctly after the fix.
48+
49+
This ensures our fix doesn't break the normal case where streaming succeeds.
50+
"""
51+
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
52+
53+
from agents.usage import Usage
54+
55+
from .test_responses import get_text_message
56+
57+
model = FakeModel()
58+
59+
# Set custom usage to verify it's tracked correctly
60+
model.set_hardcoded_usage(
61+
Usage(
62+
requests=1,
63+
input_tokens=100,
64+
output_tokens=50,
65+
total_tokens=150,
66+
input_tokens_details=InputTokensDetails(cached_tokens=10),
67+
output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
68+
)
69+
)
70+
71+
# Simulate successful streaming
72+
model.set_next_output([get_text_message("Success")])
73+
74+
agent = Agent(
75+
name="test_agent",
76+
model=model,
77+
)
78+
79+
result = Runner.run_streamed(agent, input="Test input")
80+
async for _ in result.stream_events():
81+
pass
82+
83+
# Usage should be tracked correctly in the success case
84+
assert result.context_wrapper.usage.requests == 1
85+
assert result.context_wrapper.usage.input_tokens == 100
86+
assert result.context_wrapper.usage.output_tokens == 50
87+
assert result.context_wrapper.usage.total_tokens == 150
88+
# Note: FakeModel doesn't fully support token_details, so we only test the main counts
89+
90+
91+
@pytest.mark.asyncio
92+
async def test_usage_tracking_multi_turn_with_error():
93+
"""Test usage tracking across multiple turns when an error occurs.
94+
95+
This ensures that usage from successful turns is preserved even when a later turn fails.
96+
"""
97+
import json
98+
99+
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
100+
101+
from agents.usage import Usage
102+
103+
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
104+
105+
model = FakeModel()
106+
107+
# First turn: successful with usage
108+
model.set_hardcoded_usage(
109+
Usage(
110+
requests=1,
111+
input_tokens=100,
112+
output_tokens=50,
113+
total_tokens=150,
114+
input_tokens_details=InputTokensDetails(cached_tokens=0),
115+
output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
116+
)
117+
)
118+
119+
agent = Agent(
120+
name="test_agent",
121+
model=model,
122+
tools=[get_function_tool("test_tool", "tool_result")],
123+
)
124+
125+
model.add_multiple_turn_outputs(
126+
[
127+
# First turn: successful tool call
128+
[get_function_tool_call("test_tool", json.dumps({"arg": "value"}))],
129+
# Second turn: error
130+
RuntimeError("API error on second turn"),
131+
]
132+
)
133+
134+
with pytest.raises(RuntimeError):
135+
result = Runner.run_streamed(agent, input="Test input")
136+
async for _ in result.stream_events():
137+
pass
138+
139+
# Usage should include first turn's usage + second turn's request count
140+
assert result.context_wrapper.usage.requests == 2, "Should track both turns"
141+
assert result.context_wrapper.usage.input_tokens == 100, "Should preserve first turn's tokens"
142+
assert result.context_wrapper.usage.output_tokens == 50, "Should preserve first turn's tokens"
143+
assert result.context_wrapper.usage.total_tokens == 150, "Should preserve first turn's tokens"

0 commit comments

Comments
 (0)