@@ -166,6 +166,7 @@ def trace(
166166 overwrite_trace_id : Optional [str ] = None ,
167167 overwrite_inputs : Optional [Dict [str , Any ]] = None ,
168168 log_sample_rate : Optional [float ] = 1.0 ,
169+ fn_transform_generator_outputs : Callable [[List [Any ]], str ] = None ,
169170):
170171 def init_trace (func_name , _parea_target_field , args , kwargs , func ) -> Tuple [str , datetime , contextvars .Token ]:
171172 start_time = timezone_aware_now ()
@@ -258,24 +259,60 @@ def cleanup_trace(trace_id: str, start_time: datetime, context_token: contextvar
258259 thread_eval_funcs_then_log (trace_id , eval_funcs )
259260 trace_context .reset (context_token )
260261
262+ def _handle_iterator_cleanup (items , trace_id , start_time , context_token ):
263+ if fn_transform_generator_outputs :
264+ result = fn_transform_generator_outputs (items )
265+ elif all (isinstance (item , str ) for item in items ):
266+ result = "" .join (items )
267+ else :
268+ result = ""
269+ if not is_logging_disabled () and not log_omit_outputs :
270+ fill_trace_data (trace_id , {"result" : result }, UpdateTraceScenario .RESULT )
271+
272+ cleanup_trace (trace_id , start_time , context_token )
273+
274+ async def _wrap_async_iterator (iterator , trace_id , start_time , context_token ):
275+ items = []
276+ try :
277+ async for item in iterator :
278+ items .append (item )
279+ yield item
280+ finally :
281+ _handle_iterator_cleanup (items , trace_id , start_time , context_token )
282+
283+ def _wrap_sync_iterator (iterator , trace_id , start_time , context_token ):
284+ items = []
285+ try :
286+ for item in iterator :
287+ items .append (item )
288+ yield item
289+ finally :
290+ _handle_iterator_cleanup (items , trace_id , start_time , context_token )
291+
261292 def decorator (func ):
262293 @wraps (func )
263294 async def async_wrapper (* args , ** kwargs ):
264295 _parea_target_field = kwargs .pop ("_parea_target_field" , None )
265296 trace_id , start_time , context_token = init_trace (func .__name__ , _parea_target_field , args , kwargs , func )
266297 output_as_list = check_multiple_return_values (func )
298+ result = None
267299 try :
268300 result = await func (* args , ** kwargs )
269301 if not is_logging_disabled () and not log_omit_outputs :
270302 fill_trace_data (trace_id , {"result" : result , "output_as_list" : output_as_list , "eval_funcs_names" : eval_funcs_names }, UpdateTraceScenario .RESULT )
271- return result
272303 except Exception as e :
273304 logger .error (f"Error occurred in function { func .__name__ } , { e } " )
274305 fill_trace_data (trace_id , {"error" : traceback .format_exc ()}, UpdateTraceScenario .ERROR )
275306 raise e
276307 finally :
277308 try :
278- cleanup_trace (trace_id , start_time , context_token )
309+ if inspect .isasyncgen (result ):
310+ return _wrap_async_iterator (result , trace_id , start_time , context_token )
311+ else :
312+ cleanup_trace (trace_id , start_time , context_token )
313+ # to not swallow any exceptions
314+ if result is not None :
315+ return result
279316 except Exception as e :
280317 logger .debug (f"Error occurred cleaning up trace for function { func .__name__ } , { e } " , exc_info = e )
281318
@@ -284,18 +321,24 @@ def wrapper(*args, **kwargs):
284321 _parea_target_field = kwargs .pop ("_parea_target_field" , None )
285322 trace_id , start_time , context_token = init_trace (func .__name__ , _parea_target_field , args , kwargs , func )
286323 output_as_list = check_multiple_return_values (func )
324+ result = None
287325 try :
288326 result = func (* args , ** kwargs )
289327 if not is_logging_disabled () and not log_omit_outputs :
290328 fill_trace_data (trace_id , {"result" : result , "output_as_list" : output_as_list , "eval_funcs_names" : eval_funcs_names }, UpdateTraceScenario .RESULT )
291- return result
292329 except Exception as e :
293330 logger .error (f"Error occurred in function { func .__name__ } , { e } " )
294331 fill_trace_data (trace_id , {"error" : traceback .format_exc ()}, UpdateTraceScenario .ERROR )
295332 raise e
296333 finally :
297334 try :
298- cleanup_trace (trace_id , start_time , context_token )
335+ if inspect .isgenerator (result ):
336+ return _wrap_sync_iterator (result , trace_id , start_time , context_token )
337+ else :
338+ cleanup_trace (trace_id , start_time , context_token )
339+ # to not swallow any exceptions
340+ if result is not None :
341+ return result
299342 except Exception as e :
300343 logger .debug (f"Error occurred cleaning up trace for function { func .__name__ } , { e } " , exc_info = e )
301344
0 commit comments