|
15 | 15 | from typing import TYPE_CHECKING |
16 | 16 |
|
17 | 17 | if TYPE_CHECKING: |
18 | | - from typing import Any, Iterable, List, Optional, Callable, Iterator |
| 18 | + from typing import Any, Iterable, List, Optional, Callable, AsyncIterator, Iterator |
19 | 19 | from sentry_sdk.tracing import Span |
20 | 20 |
|
21 | 21 | try: |
@@ -165,7 +165,7 @@ def _new_chat_completion_common(f, *args, **kwargs): |
165 | 165 | elif hasattr(res, "_iterator"): |
166 | 166 | data_buf: list[list[str]] = [] # one for each choice |
167 | 167 |
|
168 | | - old_iterator = res._iterator # type: Iterator[ChatCompletionChunk] |
| 168 | + old_iterator = res._iterator |
169 | 169 |
|
170 | 170 | def new_iterator(): |
171 | 171 | # type: () -> Iterator[ChatCompletionChunk] |
@@ -200,7 +200,44 @@ def new_iterator(): |
200 | 200 | ) |
201 | 201 | span.__exit__(None, None, None) |
202 | 202 |
|
203 | | - res._iterator = new_iterator() |
| 203 | + async def new_iterator_async(): |
| 204 | + # type: () -> AsyncIterator[ChatCompletionChunk] |
| 205 | + with capture_internal_exceptions(): |
| 206 | + async for x in old_iterator: |
| 207 | + if hasattr(x, "choices"): |
| 208 | + choice_index = 0 |
| 209 | + for choice in x.choices: |
| 210 | + if hasattr(choice, "delta") and hasattr( |
| 211 | + choice.delta, "content" |
| 212 | + ): |
| 213 | + content = choice.delta.content |
| 214 | + if len(data_buf) <= choice_index: |
| 215 | + data_buf.append([]) |
| 216 | + data_buf[choice_index].append(content or "") |
| 217 | + choice_index += 1 |
| 218 | + yield x |
| 219 | + if len(data_buf) > 0: |
| 220 | + all_responses = list( |
| 221 | + map(lambda chunk: "".join(chunk), data_buf) |
| 222 | + ) |
| 223 | + if should_send_default_pii() and integration.include_prompts: |
| 224 | + set_data_normalized( |
| 225 | + span, SPANDATA.AI_RESPONSES, all_responses |
| 226 | + ) |
| 227 | + _calculate_chat_completion_usage( |
| 228 | + messages, |
| 229 | + res, |
| 230 | + span, |
| 231 | + all_responses, |
| 232 | + integration.count_tokens, |
| 233 | + ) |
| 234 | + span.__exit__(None, None, None) |
| 235 | + |
| 236 | + if str(type(res._iterator)) == "<class 'async_generator'>": |
| 237 | + res._iterator = new_iterator_async() |
| 238 | + else: |
| 239 | + res._iterator = new_iterator() |
| 240 | + |
204 | 241 | else: |
205 | 242 | set_data_normalized(span, "unknown_response", True) |
206 | 243 | span.__exit__(None, None, None) |
|
0 commit comments