Skip to content

Commit b82cf99

Browse files
committed
Added async test cases for openai
1 parent fa46f61 commit b82cf99

File tree

3 files changed

+378
-4
lines changed

3 files changed

+378
-4
lines changed

sentry_sdk/integrations/openai.py

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import TYPE_CHECKING
1616

1717
if TYPE_CHECKING:
18-
from typing import Any, Iterable, List, Optional, Callable, Iterator
18+
from typing import Any, Iterable, List, Optional, Callable, AsyncIterator, Iterator
1919
from sentry_sdk.tracing import Span
2020

2121
try:
@@ -165,7 +165,7 @@ def _new_chat_completion_common(f, *args, **kwargs):
165165
elif hasattr(res, "_iterator"):
166166
data_buf: list[list[str]] = [] # one for each choice
167167

168-
old_iterator = res._iterator # type: Iterator[ChatCompletionChunk]
168+
old_iterator = res._iterator
169169

170170
def new_iterator():
171171
# type: () -> Iterator[ChatCompletionChunk]
@@ -200,7 +200,44 @@ def new_iterator():
200200
)
201201
span.__exit__(None, None, None)
202202

203-
res._iterator = new_iterator()
203+
async def new_iterator_async():
204+
# type: () -> AsyncIterator[ChatCompletionChunk]
205+
with capture_internal_exceptions():
206+
async for x in old_iterator:
207+
if hasattr(x, "choices"):
208+
choice_index = 0
209+
for choice in x.choices:
210+
if hasattr(choice, "delta") and hasattr(
211+
choice.delta, "content"
212+
):
213+
content = choice.delta.content
214+
if len(data_buf) <= choice_index:
215+
data_buf.append([])
216+
data_buf[choice_index].append(content or "")
217+
choice_index += 1
218+
yield x
219+
if len(data_buf) > 0:
220+
all_responses = list(
221+
map(lambda chunk: "".join(chunk), data_buf)
222+
)
223+
if should_send_default_pii() and integration.include_prompts:
224+
set_data_normalized(
225+
span, SPANDATA.AI_RESPONSES, all_responses
226+
)
227+
_calculate_chat_completion_usage(
228+
messages,
229+
res,
230+
span,
231+
all_responses,
232+
integration.count_tokens,
233+
)
234+
span.__exit__(None, None, None)
235+
236+
if str(type(res._iterator)) == "<class 'async_generator'>":
237+
res._iterator = new_iterator_async()
238+
else:
239+
res._iterator = new_iterator()
240+
204241
else:
205242
set_data_normalized(span, "unknown_response", True)
206243
span.__exit__(None, None, None)

0 commit comments

Comments
 (0)