Skip to content

Commit 7aa08f5

Browse files
Propagate OpenAI request id into streaming traces
1 parent 503a6ea commit 7aa08f5

File tree

2 files changed

+76
-6
lines changed

2 files changed

+76
-6
lines changed

src/agents/models/openai_responses.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import TYPE_CHECKING, Any, Literal, cast, overload
88

99
from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven
10+
from openai._models import add_request_id
1011
from openai.types import ChatModel
1112
from openai.types.responses import (
1213
Response,
@@ -180,16 +181,26 @@ async def stream_response(
180181
prompt=prompt,
181182
)
182183

184+
request_id = None
185+
stream_response = getattr(stream, "response", None)
186+
if stream_response is not None:
187+
headers = getattr(stream_response, "headers", None)
188+
if headers is not None:
189+
request_id = headers.get("x-request-id")
190+
183191
final_response: Response | None = None
184192

185193
async for chunk in stream:
186194
if isinstance(chunk, ResponseCompletedEvent):
187195
final_response = chunk.response
188196
yield chunk
189197

190-
if final_response and tracing.include_data():
191-
span_response.span_data.response = final_response
192-
span_response.span_data.input = input
198+
if final_response:
199+
if request_id:
200+
add_request_id(final_response, request_id)
201+
if tracing.include_data():
202+
span_response.span_data.response = final_response
203+
span_response.span_data.input = input
193204

194205
except Exception as e:
195206
span_response.set_error(

tests/test_agent_tracing.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
from __future__ import annotations
22

33
import asyncio
4+
from types import SimpleNamespace
45

56
import pytest
67
from inline_snapshot import snapshot
8+
from openai.types.responses import ResponseCompletedEvent
79

8-
from agents import Agent, RunConfig, Runner, trace
10+
from agents import Agent, OpenAIResponsesModel, RunConfig, Runner, trace
11+
from agents.tracing import ResponseSpanData
912

10-
from .fake_model import FakeModel
13+
from .fake_model import FakeModel, get_response_obj
1114
from .test_responses import get_text_message
12-
from .testing_processor import assert_no_traces, fetch_normalized_spans
15+
from .testing_processor import (
16+
assert_no_traces,
17+
fetch_normalized_spans,
18+
fetch_ordered_spans,
19+
)
1320

1421

1522
@pytest.mark.asyncio
@@ -292,6 +299,58 @@ async def test_streaming_single_run_is_single_trace():
292299
)
293300

294301

302+
@pytest.mark.asyncio
303+
@pytest.mark.allow_call_model_methods
304+
async def test_streamed_response_request_id_recorded():
305+
request_id = "req_test_123"
306+
307+
class DummyStream:
308+
def __init__(self) -> None:
309+
self.response = SimpleNamespace(headers={"x-request-id": request_id})
310+
311+
def __aiter__(self):
312+
async def gen():
313+
yield ResponseCompletedEvent(
314+
type="response.completed",
315+
response=get_response_obj([get_text_message("first_test")]),
316+
sequence_number=0,
317+
)
318+
319+
return gen()
320+
321+
class DummyResponses:
322+
async def create(self, **kwargs):
323+
assert kwargs.get("stream") is True
324+
return DummyStream()
325+
326+
class DummyResponsesClient:
327+
def __init__(self) -> None:
328+
self.responses = DummyResponses()
329+
330+
model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyResponsesClient()) # type: ignore[arg-type]
331+
332+
agent = Agent(
333+
name="test_agent",
334+
model=model,
335+
)
336+
337+
result = Runner.run_streamed(agent, input="first_test")
338+
async for _ in result.stream_events():
339+
pass
340+
341+
response_spans = [
342+
span
343+
for span in fetch_ordered_spans()
344+
if isinstance(span.span_data, ResponseSpanData) and span.span_data.response is not None
345+
]
346+
347+
assert response_spans
348+
assert any(
349+
getattr(span.span_data.response, "_request_id", None) == request_id
350+
for span in response_spans
351+
)
352+
353+
295354
@pytest.mark.asyncio
296355
async def test_multiple_streamed_runs_are_multiple_traces():
297356
model = FakeModel()

0 commit comments

Comments (0)