Skip to content

Commit 1769c04

Browse files
fix(openai): async pagination for OpenAI list methods [backport 3.17] (#15067)
Backport 52929aa from #14911 to 3.17. Hey, so this is an attempt to fix #14574 where doing `async for model in client.models.list()` would fail with `TypeError: 'async for' requires an object with __aiter__ method, got coroutine`. ### The Problem Methods like `AsyncModels.list()` and `AsyncFiles.list()` don't actually return coroutines - they return `AsyncPaginator` objects that you can either: - `await` to get the first page (what existing code does) - Use with `async for` to iterate through all items (what was broken) But our wrapper in `_patched_endpoint_async` was converting everything into coroutines, which broke the `async for` use case. ### What I Tried First attempt was using `inspect.iscoroutinefunction()` to detect which methods are actually async vs just returning async objects. That got messy fast because checking unbound methods from classes didn't work reliably. Then I tried just using the sync wrapper for list methods: ```python if method_name == "list": wrap(openai, async_method, _patched_endpoint(openai, endpoint_hook)) ``` This looked promising - the pagination tests passed! But it broke `test_model_alist` and `test_file_alist` because those tests do `await client.models.list()` and expect full tracing with response metadata like `openai.response.count`. Using the sync wrapper meant we lost all that when the paginator was awaited. Also tried returning the paginator directly without any wrapping, but that meant we lost tracing entirely when someone did `async for`. Not acceptable. ### A Solution Created a `_TracedAsyncPaginator` wrapper class that implements both `__aiter__` and `__await__`. 
This way: - When you do `await client.models.list()` -> calls `__await__`, traces properly, returns first page (existing behavior preserved) - When you do `async for model in client.models.list()` -> calls `__aiter__`, traces on first iteration, yields items (fixes the bug) The wrapper is ~50 lines but it's the minimal solution that preserves 100% backward compatibility while fixing the breaking bug. Had to use `finally` blocks to ensure traces complete even if iteration stops early. ### Testing Added two new pagination tests (`test_model_list_pagination` and `test_model_alist_pagination`) that specifically test the `async for` pattern. Co-authored-by: Alexandre Choura <[email protected]>
1 parent ac57ec7 commit 1769c04

File tree

6 files changed

+200
-25
lines changed

6 files changed

+200
-25
lines changed

ddtrace/contrib/internal/openai/_endpoint_hooks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,8 @@ def _record_response(self, pin, integration, span, args, kwargs, resp, error):
260260
resp = super()._record_response(pin, integration, span, args, kwargs, resp, error)
261261
if not resp:
262262
return
263-
span.set_metric("openai.response.count", len(resp.data or []))
263+
if hasattr(resp, "data"):
264+
span.set_metric("openai.response.count", len(resp.data or []))
264265
return resp
265266

266267

ddtrace/contrib/internal/openai/patch.py

Lines changed: 101 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -286,40 +286,117 @@ def patched_endpoint(openai, pin, func, instance, args, kwargs):
286286
return patched_endpoint(openai)
287287

288288

289+
class _TracedAsyncPaginator:
290+
"""Wrapper for AsyncPaginator objects to enable tracing for both await and async for usage."""
291+
292+
def __init__(self, paginator, pin, integration, patch_hook, instance, args, kwargs):
293+
self._paginator = paginator
294+
self._pin = pin
295+
self._integration = integration
296+
self._patch_hook = patch_hook
297+
self._instance = instance
298+
self._args = args
299+
self._kwargs = kwargs
300+
301+
def __aiter__(self):
302+
async def _traced_aiter():
303+
g = _traced_endpoint(
304+
self._patch_hook, self._integration, self._instance, self._pin, self._args, self._kwargs
305+
)
306+
g.send(None)
307+
err = None
308+
completed = False
309+
try:
310+
iterator = self._paginator.__aiter__()
311+
# Fetch first item to trigger trace completion before iteration starts.
312+
# This ensures the span is recorded even if iteration stops early.
313+
first_item = await iterator.__anext__()
314+
try:
315+
g.send((None, None))
316+
completed = True
317+
except StopIteration:
318+
completed = True
319+
yield first_item
320+
async for item in iterator:
321+
yield item
322+
except StopAsyncIteration:
323+
pass
324+
except BaseException as e:
325+
err = e
326+
raise
327+
finally:
328+
if not completed:
329+
try:
330+
g.send((None, err))
331+
except StopIteration:
332+
pass
333+
334+
return _traced_aiter()
335+
336+
def __await__(self):
337+
async def _trace_and_await():
338+
g = _traced_endpoint(
339+
self._patch_hook, self._integration, self._instance, self._pin, self._args, self._kwargs
340+
)
341+
g.send(None)
342+
resp, err = None, None
343+
try:
344+
resp = await self._paginator
345+
except BaseException as e:
346+
err = e
347+
raise
348+
finally:
349+
try:
350+
g.send((resp, err))
351+
except StopIteration as e:
352+
if err is None:
353+
return e.value
354+
return resp
355+
356+
return _trace_and_await().__await__()
357+
358+
289359
def _patched_endpoint_async(openai, patch_hook):
    """Build the async wrapper applied to OpenAI client endpoint methods.

    Unlike the sync variant, the wrapped function is NOT declared ``async``:
    methods such as ``AsyncModels.list`` return an ``AsyncPaginator`` (both
    awaitable and async-iterable), and wrapping them in a coroutine would
    break ``async for`` usage. Paginators are therefore wrapped in
    ``_TracedAsyncPaginator``; everything else is awaited inside a traced
    coroutine.
    """

    @with_traced_module
    def patched_endpoint(openai, pin, func, instance, args, kwargs):
        # with_raw_response endpoints: tag the call and skip tracing here
        # (the raw-response hook handles it).
        if (
            patch_hook is _endpoint_hooks._ChatCompletionWithRawResponseHook
            or patch_hook is _endpoint_hooks._CompletionWithRawResponseHook
        ):
            kwargs[OPENAI_WITH_RAW_RESPONSE_ARG] = True
            return func(*args, **kwargs)
        # Streaming raw responses are also passed through untraced.
        if kwargs.pop(OPENAI_WITH_RAW_RESPONSE_ARG, False) and kwargs.get("stream", False):
            return func(*args, **kwargs)

        result = func(*args, **kwargs)
        # AsyncPaginator detection: an object exposing both __aiter__ and
        # __await__. It must be returned directly — not awaited — so that
        # `async for` over it keeps working.
        if hasattr(result, "__aiter__") and hasattr(result, "__await__"):
            return _TracedAsyncPaginator(result, pin, openai._datadog_integration, patch_hook, instance, args, kwargs)

        async def async_wrapper():
            # Standard coroutine path: drive the _traced_endpoint generator
            # around the awaited call.
            integration = openai._datadog_integration
            trace_gen = _traced_endpoint(patch_hook, integration, instance, pin, args, kwargs)
            trace_gen.send(None)
            response, error = None, None
            replacement = None
            try:
                response = await result
            except BaseException as exc:
                error = exc
                raise
            finally:
                try:
                    trace_gen.send((response, error))
                except StopIteration as stop:
                    if error is None:
                        # A value returned by the hook takes priority over
                        # the raw response.
                        replacement = stop.value

            return replacement if replacement is not None else response

        return async_wrapper()

    return patched_endpoint(openai)
325402

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
openai: This fix resolves an issue where using async iteration with paginated methods (e.g., ``async for model in client.models.list()``) caused a ``TypeError: 'async for' requires an object with __aiter__ method, got coroutine``. See `issue #14574 <https://github.com/DataDog/dd-trace-py/issues/14574>`_.

tests/contrib/openai/test_openai_v1.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,22 @@ def test_model_list(api_key_in_env, request_api_key, openai, openai_vcr, snapsho
3535
client.models.list()
3636

3737

38+
@pytest.mark.parametrize("api_key_in_env", [True, False])
def test_model_list_pagination(api_key_in_env, request_api_key, openai, openai_vcr, snapshot_tracer):
    """Sync pagination: iterating ``client.models.list()`` with ``for`` must
    yield items and produce the expected snapshot span, even when iteration
    stops before the page is exhausted."""
    with snapshot_context(
        token="tests.contrib.openai.test_openai.test_model_list_pagination",
        ignores=["meta.http.useragent", "meta.openai.api_type", "meta.openai.api_base", "meta.openai.request.user"],
    ):
        with openai_vcr.use_cassette("model_list.yaml"):
            client = openai.OpenAI(api_key=request_api_key)
            seen = 0
            for _model in client.models.list():
                seen += 1
                # Break early on purpose: the span must still be recorded.
                if seen >= 2:
                    break
            assert seen >= 2
52+
53+
3854
@pytest.mark.parametrize("api_key_in_env", [True, False])
3955
async def test_model_alist(api_key_in_env, request_api_key, openai, openai_vcr, snapshot_tracer):
4056
with snapshot_context(
@@ -46,6 +62,22 @@ async def test_model_alist(api_key_in_env, request_api_key, openai, openai_vcr,
4662
await client.models.list()
4763

4864

65+
@pytest.mark.parametrize("api_key_in_env", [True, False])
async def test_model_alist_pagination(api_key_in_env, request_api_key, openai, openai_vcr, snapshot_tracer):
    """Async pagination: ``async for`` over ``client.models.list()`` must not
    raise (regression test for the coroutine-instead-of-paginator bug) and
    must still trace, even when iteration stops early."""
    with snapshot_context(
        token="tests.contrib.openai.test_openai.test_model_alist_pagination",
        ignores=["meta.http.useragent", "meta.openai.api_type", "meta.openai.api_base", "meta.openai.request.user"],
    ):
        with openai_vcr.use_cassette("model_alist.yaml"):
            client = openai.AsyncOpenAI(api_key=request_api_key)
            seen = 0
            async for _model in client.models.list():
                seen += 1
                # Break early on purpose: the span must still be recorded.
                if seen >= 2:
                    break
            assert seen >= 2
79+
80+
4981
@pytest.mark.parametrize("api_key_in_env", [True, False])
5082
def test_model_retrieve(api_key_in_env, request_api_key, openai, openai_vcr, snapshot_tracer):
5183
with snapshot_context(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
[[
2+
{
3+
"name": "openai.request",
4+
"service": "tests.contrib.openai",
5+
"resource": "listModels",
6+
"trace_id": 0,
7+
"span_id": 1,
8+
"parent_id": 0,
9+
"type": "",
10+
"error": 0,
11+
"meta": {
12+
"_dd.p.dm": "-0",
13+
"_dd.p.tid": "68f0b1d700000000",
14+
"component": "openai",
15+
"language": "python",
16+
"openai.request.endpoint": "/v1/models",
17+
"openai.request.method": "GET",
18+
"openai.request.provider": "OpenAI",
19+
"runtime-id": "1e2a3154601a494f8f219a4327b659c2"
20+
},
21+
"metrics": {
22+
"_dd.measured": 1,
23+
"_dd.top_level": 1,
24+
"_dd.tracer_kr": 1.0,
25+
"_sampling_priority_v1": 1,
26+
"process_id": 573
27+
},
28+
"duration": 1683125,
29+
"start": 1760604631675824507
30+
}]]
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
[[
2+
{
3+
"name": "openai.request",
4+
"service": "tests.contrib.openai",
5+
"resource": "listModels",
6+
"trace_id": 0,
7+
"span_id": 1,
8+
"parent_id": 0,
9+
"type": "",
10+
"error": 0,
11+
"meta": {
12+
"_dd.p.dm": "-0",
13+
"_dd.p.tid": "68f0b1d500000000",
14+
"component": "openai",
15+
"language": "python",
16+
"openai.request.endpoint": "/v1/models",
17+
"openai.request.method": "GET",
18+
"openai.request.provider": "OpenAI",
19+
"runtime-id": "1e2a3154601a494f8f219a4327b659c2"
20+
},
21+
"metrics": {
22+
"_dd.measured": 1,
23+
"_dd.top_level": 1,
24+
"_dd.tracer_kr": 1.0,
25+
"_sampling_priority_v1": 1,
26+
"openai.response.count": 112,
27+
"process_id": 573
28+
},
29+
"duration": 13777416,
30+
"start": 1760604629974266007
31+
}]]

0 commit comments

Comments
 (0)