1- from typing import Any , Callable , Mapping , Tuple
1+ import logging
2+ from typing import Any , Callable , Mapping , Tuple , List
23
34import contextvars
45from json import JSONDecodeError
1213from parea .schemas import EvaluationResult , UpdateLog
1314from parea .utils .trace_integrations .wrapt_utils import CopyableFunctionWrapper
1415from parea .utils .trace_utils import logger_update_record , trace_data , trace_insert
16+ from parea .utils .universal_encoder import json_dumps
17+
18+ logger = logging .getLogger ()
1519
1620instructor_trace_id = contextvars .ContextVar ("instructor_trace_id" , default = "" )
1721instructor_val_err_count = contextvars .ContextVar ("instructor_val_err_count" , default = 0 )
@@ -50,14 +54,11 @@ def report_instructor_validation_errors() -> None:
5054 score = instructor_val_err_count .get (),
5155 reason = reason ,
5256 )
53- last_child_trace_id = trace_data .get ()[instructor_trace_id .get ()].children [- 1 ]
54- trace_insert (
55- {
56- "scores" : [instructor_score ],
57- "configuration" : trace_data .get ()[last_child_trace_id ].configuration ,
58- },
59- instructor_trace_id .get (),
60- )
57+ trace_update_dict = {"scores" : [instructor_score ]}
58+ if children := trace_data .get ()[instructor_trace_id .get ()].children :
59+ last_child_trace_id = children [- 1 ]
60+ trace_update_dict ["configuration" ] = trace_data .get ()[last_child_trace_id ].configuration
61+ trace_insert (trace_update_dict , instructor_trace_id .get ())
6162 instructor_trace_id .set ("" )
6263 instructor_val_err_count .set (0 )
6364 instructor_val_errs .set ([])
@@ -82,11 +83,20 @@ def __call__(
8283 trace_name = "instructor"
8384 if "response_model" in kwargs and kwargs ["response_model" ] and hasattr (kwargs ["response_model" ], "__name__" ):
8485 trace_name = kwargs ["response_model" ].__name__
86+
87+ def fn_transform_generator_outputs (items : List ) -> str :
88+ try :
89+ return json_dumps (items [- 1 ])
90+ except Exception as e :
91+ logger .warning (f"Failed to serialize generator output: { e } " , exc_info = e )
92+ return ""
93+
8594 return trace (
8695 name = trace_name ,
8796 overwrite_trace_id = trace_id ,
8897 overwrite_inputs = inputs ,
8998 metadata = metadata ,
99+ fn_transform_generator_outputs = fn_transform_generator_outputs
90100 )(
91101 wrapped
92102 )(* args , ** kwargs )
0 commit comments