Skip to content

Commit 7630423

Browse files
committed
handle iterators not only generators
1 parent 96b2bf8 commit 7630423

File tree

1 file changed

+55
-55
lines changed

1 file changed

+55
-55
lines changed

langfuse/_client/observe.py

Lines changed: 55 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
from functools import wraps
77
from typing import (
88
Any,
9-
AsyncGenerator,
109
Callable,
1110
Dict,
12-
Generator,
1311
Iterable,
1412
List,
1513
Optional,
@@ -19,6 +17,7 @@
1917
cast,
2018
overload,
2119
)
20+
from collections.abc import AsyncIterator, Iterator
2221

2322
from opentelemetry.util._decorator import _AgnosticContextManager
2423
from typing_extensions import ParamSpec
@@ -278,7 +277,7 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
278277
as_type=as_type or "span",
279278
trace_context=trace_context,
280279
input=input,
281-
end_on_exit=False, # when returning a generator, closing on exit would be to early
280+
end_on_exit=False, # when returning a iterator, closing on exit would be to early
282281
)
283282
if langfuse_client
284283
else None
@@ -288,25 +287,16 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
288287
return await func(*args, **kwargs)
289288

290289
with context_manager as langfuse_span_or_generation:
291-
is_return_type_generator = False
290+
is_return_type_iterator = False
292291

293292
try:
294293
result = await func(*args, **kwargs)
295294

296295
if capture_output is True:
297-
if inspect.isgenerator(result):
298-
is_return_type_generator = True
296+
if isinstance(result, Iterator):
297+
is_return_type_iterator = True
299298

300-
return self._wrap_sync_generator_result(
301-
langfuse_span_or_generation,
302-
result,
303-
transform_to_string,
304-
)
305-
306-
if inspect.isasyncgen(result):
307-
is_return_type_generator = True
308-
309-
return self._wrap_async_generator_result(
299+
return self._wrap_sync_iterator_result(
310300
langfuse_span_or_generation,
311301
result,
312302
transform_to_string,
@@ -316,15 +306,24 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
316306
if type(result).__name__ == "StreamingResponse" and hasattr(
317307
result, "body_iterator"
318308
):
319-
is_return_type_generator = True
309+
is_return_type_iterator = True
320310

321311
result.body_iterator = (
322-
self._wrap_async_generator_result(
312+
self._wrap_async_iterator_result(
323313
langfuse_span_or_generation,
324314
result.body_iterator,
325315
transform_to_string,
326316
)
327317
)
318+
319+
if isinstance(result, AsyncIterator):
320+
is_return_type_iterator = True
321+
322+
return self._wrap_async_iterator_result(
323+
langfuse_span_or_generation,
324+
result,
325+
transform_to_string,
326+
)
328327

329328
langfuse_span_or_generation.update(output=result)
330329

@@ -336,7 +335,7 @@ async def async_wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> Any:
336335

337336
raise e
338337
finally:
339-
if not is_return_type_generator:
338+
if not is_return_type_iterator:
340339
langfuse_span_or_generation.end()
341340

342341
return cast(F, async_wrapper)
@@ -396,7 +395,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
396395
as_type=as_type or "span",
397396
trace_context=trace_context,
398397
input=input,
399-
end_on_exit=False, # when returning a generator, closing on exit would be to early
398+
end_on_exit=False, # when returning a iterator, closing on exit would be to early
400399
)
401400
if langfuse_client
402401
else None
@@ -406,25 +405,25 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
406405
return func(*args, **kwargs)
407406

408407
with context_manager as langfuse_span_or_generation:
409-
is_return_type_generator = False
408+
is_return_type_iterator = False
410409

411410
try:
412411
result = func(*args, **kwargs)
413412

414413
if capture_output is True:
415-
if inspect.isgenerator(result):
416-
is_return_type_generator = True
414+
if isinstance(result, Iterator):
415+
is_return_type_iterator = True
417416

418-
return self._wrap_sync_generator_result(
417+
return self._wrap_sync_iterator_result(
419418
langfuse_span_or_generation,
420419
result,
421420
transform_to_string,
422421
)
423422

424-
if inspect.isasyncgen(result):
425-
is_return_type_generator = True
423+
if isinstance(result, AsyncIterator):
424+
is_return_type_iterator = True
426425

427-
return self._wrap_async_generator_result(
426+
return self._wrap_async_iterator_result(
428427
langfuse_span_or_generation,
429428
result,
430429
transform_to_string,
@@ -434,15 +433,16 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
434433
if type(result).__name__ == "StreamingResponse" and hasattr(
435434
result, "body_iterator"
436435
):
437-
is_return_type_generator = True
436+
is_return_type_iterator = True
438437

439438
result.body_iterator = (
440-
self._wrap_async_generator_result(
439+
self._wrap_async_iterator_result(
441440
langfuse_span_or_generation,
442441
result.body_iterator,
443442
transform_to_string,
444443
)
445444
)
445+
446446

447447
langfuse_span_or_generation.update(output=result)
448448

@@ -454,7 +454,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
454454

455455
raise e
456456
finally:
457-
if not is_return_type_generator:
457+
if not is_return_type_iterator:
458458
langfuse_span_or_generation.end()
459459

460460
return cast(F, sync_wrapper)
@@ -481,7 +481,7 @@ def _get_input_from_func_args(
481481
"kwargs": func_kwargs,
482482
}
483483

484-
def _wrap_sync_generator_result(
484+
def _wrap_sync_iterator_result(
485485
self,
486486
langfuse_span_or_generation: Union[
487487
LangfuseSpan,
@@ -494,19 +494,19 @@ def _wrap_sync_generator_result(
494494
LangfuseEmbedding,
495495
LangfuseGuardrail,
496496
],
497-
generator: Generator,
497+
iterator: Iterator,
498498
transform_to_string: Optional[Callable[[Iterable], str]] = None,
499499
) -> Any:
500500
preserved_context = contextvars.copy_context()
501501

502-
return _ContextPreservedSyncGeneratorWrapper(
503-
generator,
502+
return _ContextPreservedSyncIteratorWrapper(
503+
iterator,
504504
preserved_context,
505505
langfuse_span_or_generation,
506506
transform_to_string,
507507
)
508508

509-
def _wrap_async_generator_result(
509+
def _wrap_async_iterator_result(
510510
self,
511511
langfuse_span_or_generation: Union[
512512
LangfuseSpan,
@@ -519,13 +519,13 @@ def _wrap_async_generator_result(
519519
LangfuseEmbedding,
520520
LangfuseGuardrail,
521521
],
522-
generator: AsyncGenerator,
522+
iterator: AsyncIterator,
523523
transform_to_string: Optional[Callable[[Iterable], str]] = None,
524524
) -> Any:
525525
preserved_context = contextvars.copy_context()
526526

527-
return _ContextPreservedAsyncGeneratorWrapper(
528-
generator,
527+
return _ContextPreservedAsyncIteratorWrapper(
528+
iterator,
529529
preserved_context,
530530
langfuse_span_or_generation,
531531
transform_to_string,
@@ -537,12 +537,12 @@ def _wrap_async_generator_result(
537537
observe = _decorator.observe
538538

539539

540-
class _ContextPreservedSyncGeneratorWrapper:
541-
"""Sync generator wrapper that ensures each iteration runs in preserved context."""
540+
class _ContextPreservedSyncIteratorWrapper:
541+
"""Sync iterator wrapper that ensures each iteration runs in preserved context."""
542542

543543
def __init__(
544544
self,
545-
generator: Generator,
545+
iterator: Iterator,
546546
context: contextvars.Context,
547547
span: Union[
548548
LangfuseSpan,
@@ -557,25 +557,25 @@ def __init__(
557557
],
558558
transform_fn: Optional[Callable[[Iterable], str]],
559559
) -> None:
560-
self.generator = generator
560+
self.iterator = iterator
561561
self.context = context
562562
self.items: List[Any] = []
563563
self.span = span
564564
self.transform_fn = transform_fn
565565

566-
def __iter__(self) -> "_ContextPreservedSyncGeneratorWrapper":
566+
def __iter__(self) -> "_ContextPreservedSyncIteratorWrapper":
567567
return self
568568

569569
def __next__(self) -> Any:
570570
try:
571-
# Run the generator's __next__ in the preserved context
572-
item = self.context.run(next, self.generator)
571+
# Run the iterator's __next__ in the preserved context
572+
item = self.context.run(next, self.iterator)
573573
self.items.append(item)
574574

575575
return item
576576

577577
except StopIteration:
578-
# Handle output and span cleanup when generator is exhausted
578+
# Handle output and span cleanup when iterator is exhausted
579579
output: Any = self.items
580580

581581
if self.transform_fn is not None:
@@ -596,12 +596,12 @@ def __next__(self) -> Any:
596596
raise
597597

598598

599-
class _ContextPreservedAsyncGeneratorWrapper:
600-
"""Async generator wrapper that ensures each iteration runs in preserved context."""
599+
class _ContextPreservedAsyncIteratorWrapper:
600+
"""Async iterator wrapper that ensures each iteration runs in preserved context."""
601601

602602
def __init__(
603603
self,
604-
generator: AsyncGenerator,
604+
iterator: AsyncIterator,
605605
context: contextvars.Context,
606606
span: Union[
607607
LangfuseSpan,
@@ -616,34 +616,34 @@ def __init__(
616616
],
617617
transform_fn: Optional[Callable[[Iterable], str]],
618618
) -> None:
619-
self.generator = generator
619+
self.iterator = iterator
620620
self.context = context
621621
self.items: List[Any] = []
622622
self.span = span
623623
self.transform_fn = transform_fn
624624

625-
def __aiter__(self) -> "_ContextPreservedAsyncGeneratorWrapper":
625+
def __aiter__(self) -> "_ContextPreservedAsyncIteratorWrapper":
626626
return self
627627

628628
async def __anext__(self) -> Any:
629629
try:
630-
# Run the generator's __anext__ in the preserved context
630+
# Run the iterator's __anext__ in the preserved context
631631
try:
632632
# Python 3.10+ approach with context parameter
633633
item = await asyncio.create_task(
634-
self.generator.__anext__(), # type: ignore
634+
self.iterator.__anext__(), # type: ignore
635635
context=self.context,
636636
) # type: ignore
637637
except TypeError:
638638
# Python < 3.10 fallback - context parameter not supported
639-
item = await self.generator.__anext__()
639+
item = await self.iterator.__anext__()
640640

641641
self.items.append(item)
642642

643643
return item
644644

645645
except StopAsyncIteration:
646-
# Handle output and span cleanup when generator is exhausted
646+
# Handle output and span cleanup when iterator is exhausted
647647
output: Any = self.items
648648

649649
if self.transform_fn is not None:

0 commit comments

Comments
 (0)