Skip to content

Commit 9a72f96

Browse files
committed
feat: update instructor streaming integration
1 parent 6472419 commit 9a72f96

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

parea/utils/trace_integrations/instructor.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Any, Callable, Mapping, Tuple
1+
import logging
2+
from typing import Any, Callable, Mapping, Tuple, List
23

34
import contextvars
45
from json import JSONDecodeError
@@ -12,6 +13,9 @@
1213
from parea.schemas import EvaluationResult, UpdateLog
1314
from parea.utils.trace_integrations.wrapt_utils import CopyableFunctionWrapper
1415
from parea.utils.trace_utils import logger_update_record, trace_data, trace_insert
16+
from parea.utils.universal_encoder import json_dumps
17+
18+
logger = logging.getLogger()
1519

1620
instructor_trace_id = contextvars.ContextVar("instructor_trace_id", default="")
1721
instructor_val_err_count = contextvars.ContextVar("instructor_val_err_count", default=0)
@@ -50,14 +54,11 @@ def report_instructor_validation_errors() -> None:
5054
score=instructor_val_err_count.get(),
5155
reason=reason,
5256
)
53-
last_child_trace_id = trace_data.get()[instructor_trace_id.get()].children[-1]
54-
trace_insert(
55-
{
56-
"scores": [instructor_score],
57-
"configuration": trace_data.get()[last_child_trace_id].configuration,
58-
},
59-
instructor_trace_id.get(),
60-
)
57+
trace_update_dict = {"scores": [instructor_score]}
58+
if children := trace_data.get()[instructor_trace_id.get()].children:
59+
last_child_trace_id = children[-1]
60+
trace_update_dict["configuration"] = trace_data.get()[last_child_trace_id].configuration
61+
trace_insert(trace_update_dict, instructor_trace_id.get())
6162
instructor_trace_id.set("")
6263
instructor_val_err_count.set(0)
6364
instructor_val_errs.set([])
@@ -82,11 +83,20 @@ def __call__(
8283
trace_name = "instructor"
8384
if "response_model" in kwargs and kwargs["response_model"] and hasattr(kwargs["response_model"], "__name__"):
8485
trace_name = kwargs["response_model"].__name__
86+
87+
def fn_transform_generator_outputs(items: List) -> str:
88+
try:
89+
return json_dumps(items[-1])
90+
except Exception as e:
91+
logger.warning(f"Failed to serialize generator output: {e}", exc_info=e)
92+
return ""
93+
8594
return trace(
8695
name=trace_name,
8796
overwrite_trace_id=trace_id,
8897
overwrite_inputs=inputs,
8998
metadata=metadata,
99+
fn_transform_generator_outputs=fn_transform_generator_outputs
90100
)(
91101
wrapped
92102
)(*args, **kwargs)

0 commit comments

Comments
 (0)