11import asyncio
2+ import contextvars
23import inspect
34import logging
45import os
1011 Dict ,
1112 Generator ,
1213 Iterable ,
14+ List ,
1315 Optional ,
1416 Tuple ,
1517 TypeVar ,
2123from opentelemetry .util ._decorator import _AgnosticContextManager
2224from typing_extensions import ParamSpec
2325
24- from langfuse ._client .environment_variables import (
25- LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED ,
26- )
27-
2826from langfuse ._client .constants import (
2927 ObservationTypeLiteralNoEvent ,
3028 get_observation_types_list ,
3129)
30+ from langfuse ._client .environment_variables import (
31+ LANGFUSE_OBSERVE_DECORATOR_IO_CAPTURE_ENABLED ,
32+ )
3233from langfuse ._client .get_client import _set_current_public_key , get_client
3334from langfuse ._client .span import (
34- LangfuseGeneration ,
35- LangfuseSpan ,
3635 LangfuseAgent ,
37- LangfuseTool ,
3836 LangfuseChain ,
39- LangfuseRetriever ,
40- LangfuseEvaluator ,
4137 LangfuseEmbedding ,
38+ LangfuseEvaluator ,
39+ LangfuseGeneration ,
4240 LangfuseGuardrail ,
41+ LangfuseRetriever ,
42+ LangfuseSpan ,
43+ LangfuseTool ,
4344)
4445from langfuse .types import TraceContext
4546
@@ -468,29 +469,54 @@ def _wrap_sync_generator_result(
468469 generator : Generator ,
469470 transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
470471 ) -> Any :
471- items = []
472+ preserved_context = contextvars . copy_context ()
472473
473- try :
474- for item in generator :
475- items .append (item )
474+ return _ContextPreservedSyncGeneratorWrapper (
475+ generator ,
476+ preserved_context ,
477+ langfuse_span_or_generation ,
478+ transform_to_string ,
479+ )
480+
481+ def _wrap_async_generator_result (
482+ self ,
483+ langfuse_span_or_generation : Union [
484+ LangfuseSpan ,
485+ LangfuseGeneration ,
486+ LangfuseAgent ,
487+ LangfuseTool ,
488+ LangfuseChain ,
489+ LangfuseRetriever ,
490+ LangfuseEvaluator ,
491+ LangfuseEmbedding ,
492+ LangfuseGuardrail ,
493+ ],
494+ generator : AsyncGenerator ,
495+ transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
496+ ) -> Any :
497+ preserved_context = contextvars .copy_context ()
476498
477- yield item
499+ return _ContextPreservedAsyncGeneratorWrapper (
500+ generator ,
501+ preserved_context ,
502+ langfuse_span_or_generation ,
503+ transform_to_string ,
504+ )
478505
479- finally :
480- output : Any = items
481506
482- if transform_to_string is not None :
483- output = transform_to_string (items )
507+ _decorator = LangfuseDecorator ()
508+
509+ observe = _decorator .observe
484510
485- elif all (isinstance (item , str ) for item in items ):
486- output = "" .join (items )
487511
488- langfuse_span_or_generation . update ( output = output )
489- langfuse_span_or_generation . end ()
512+ class _ContextPreservedSyncGeneratorWrapper :
513+ """Sync generator wrapper that ensures each iteration runs in preserved context."""
490514
491- async def _wrap_async_generator_result (
515+ def __init__ (
492516 self ,
493- langfuse_span_or_generation : Union [
517+ generator : Generator ,
518+ context : contextvars .Context ,
519+ span : Union [
494520 LangfuseSpan ,
495521 LangfuseGeneration ,
496522 LangfuseAgent ,
@@ -501,30 +527,105 @@ async def _wrap_async_generator_result(
501527 LangfuseEmbedding ,
502528 LangfuseGuardrail ,
503529 ],
504- generator : AsyncGenerator ,
505- transform_to_string : Optional [Callable [[Iterable ], str ]] = None ,
506- ) -> AsyncGenerator :
507- items = []
530+ transform_fn : Optional [Callable [[Iterable ], str ]],
531+ ) -> None :
532+ self .generator = generator
533+ self .context = context
534+ self .items : List [Any ] = []
535+ self .span = span
536+ self .transform_fn = transform_fn
537+
538+ def __iter__ (self ) -> "_ContextPreservedSyncGeneratorWrapper" :
539+ return self
540+
541+ def __next__ (self ) -> Any :
542+ try :
543+ # Run the generator's __next__ in the preserved context
544+ item = self .context .run (next , self .generator )
545+ self .items .append (item )
546+
547+ return item
548+
549+ except StopIteration :
550+ # Handle output and span cleanup when generator is exhausted
551+ output : Any = self .items
552+
553+ if self .transform_fn is not None :
554+ output = self .transform_fn (self .items )
555+
556+ elif all (isinstance (item , str ) for item in self .items ):
557+ output = "" .join (self .items )
558+
559+ self .span .update (output = output ).end ()
560+
561+ raise # Re-raise StopIteration
562+
563+ except Exception as e :
564+ self .span .update (level = "ERROR" , status_message = str (e )).end ()
508565
566+ raise
567+
568+
569+ class _ContextPreservedAsyncGeneratorWrapper :
570+ """Async generator wrapper that ensures each iteration runs in preserved context."""
571+
572+ def __init__ (
573+ self ,
574+ generator : AsyncGenerator ,
575+ context : contextvars .Context ,
576+ span : Union [
577+ LangfuseSpan ,
578+ LangfuseGeneration ,
579+ LangfuseAgent ,
580+ LangfuseTool ,
581+ LangfuseChain ,
582+ LangfuseRetriever ,
583+ LangfuseEvaluator ,
584+ LangfuseEmbedding ,
585+ LangfuseGuardrail ,
586+ ],
587+ transform_fn : Optional [Callable [[Iterable ], str ]],
588+ ) -> None :
589+ self .generator = generator
590+ self .context = context
591+ self .items : List [Any ] = []
592+ self .span = span
593+ self .transform_fn = transform_fn
594+
595+ def __aiter__ (self ) -> "_ContextPreservedAsyncGeneratorWrapper" :
596+ return self
597+
598+ async def __anext__ (self ) -> Any :
509599 try :
510- async for item in generator :
511- items .append (item )
600+ # Run the generator's __anext__ in the preserved context
601+ try :
602+ # Python 3.10+ approach with context parameter
603+ item = await asyncio .create_task (
604+ self .generator .__anext__ (), # type: ignore
605+ context = self .context ,
606+ ) # type: ignore
607+ except TypeError :
608+ # Python < 3.10 fallback - context parameter not supported
609+ item = await self .generator .__anext__ ()
512610
513- yield item
611+ self . items . append ( item )
514612
515- finally :
516- output : Any = items
613+ return item
517614
518- if transform_to_string is not None :
519- output = transform_to_string (items )
615+ except StopAsyncIteration :
616+ # Handle output and span cleanup when generator is exhausted
617+ output : Any = self .items
520618
521- elif all ( isinstance ( item , str ) for item in items ) :
522- output = "" . join ( items )
619+ if self . transform_fn is not None :
620+ output = self . transform_fn ( self . items )
523621
524- langfuse_span_or_generation . update ( output = output )
525- langfuse_span_or_generation . end ( )
622+ elif all ( isinstance ( item , str ) for item in self . items ):
623+ output = "" . join ( self . items )
526624
625+ self .span .update (output = output ).end ()
527626
528- _decorator = LangfuseDecorator ()
627+ raise # Re-raise StopAsyncIteration
628+ except Exception as e :
629+ self .span .update (level = "ERROR" , status_message = str (e )).end ()
529630
530- observe = _decorator . observe
631+ raise
0 commit comments