11from __future__ import annotations
22
3+ from contextlib import contextmanager
34from typing import (
45 TYPE_CHECKING ,
56 Any ,
@@ -162,7 +163,7 @@ def __call__(self, _fn: ValueFn[IT]) -> Self:
162163 raise TypeError ("Value function must be callable" )
163164
164165 # Set value function with extra meta information
165- self .fn = AsyncValueFn (_fn )
166+ self .fn = AsyncValueFn (_fn , self )
166167
167168 # Copy over function name as it is consistent with how Session and Output
168169 # retrieve function names
@@ -350,6 +351,7 @@ class AsyncValueFn(Generic[IT]):
350351 def __init__ (
351352 self ,
352353 fn : Callable [[], IT | None ] | Callable [[], Awaitable [IT | None ]],
354+ renderer : Renderer [Any ],
353355 ):
354356 if isinstance (fn , AsyncValueFn ):
355357 raise TypeError (
@@ -358,12 +360,14 @@ def __init__(
358360 self ._is_async = is_async_callable (fn )
359361 self ._fn = wrap_async (fn )
360362 self ._orig_fn = fn
363+ self ._renderer = renderer
361364
362365 async def __call__ (self ) -> IT | None :
363366 """
364367 Call the asynchronous function.
365368 """
366- return await self ._fn ()
369+ with self ._current_output_id ():
370+ return await self ._fn ()
367371
368372 def is_async (self ) -> bool :
369373 """
@@ -404,3 +408,13 @@ def get_sync_fn(self) -> Callable[[], IT | None]:
404408 )
405409 sync_fn = cast (Callable [[], IT ], self ._orig_fn )
406410 return sync_fn
411+
412+ @contextmanager
413+ def _current_output_id (self ):
414+ from ...session import get_current_session
415+
416+ session = get_current_session ()
417+ if session is not None :
418+ session .current_output_id = self ._renderer .output_id
419+ yield
420+ session .current_output_id = None
0 commit comments