Skip to content

Commit f808bd2

Browse files
authored
Merge pull request #957 from parea-ai/fix-nested-cattrs
Avoid nested cattrs check on TraceLogTree
2 parents b0741ad + 363f969 commit f808bd2

File tree

4 files changed

+1216
-1285
lines changed

4 files changed

+1216
-1285
lines changed

parea/client.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
ListExperimentUUIDsFilters,
3333
ProjectSchema,
3434
TestCaseCollection,
35+
TraceLog,
3536
TraceLogFilters,
3637
TraceLogTree,
3738
UseDeployedPrompt,
@@ -402,32 +403,32 @@ def _update_data_and_trace(self, data: Completion) -> Completion:
402403

403404
return data
404405

405-
def get_trace_log(self, trace_id: str) -> TraceLogTree:
406+
def get_trace_log(self, trace_id: str, return_children: bool = True) -> Union[TraceLogTree, TraceLog]:
406407
response = self._client.request("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
407-
return structure_trace_log_from_api(response.json())
408+
return structure_trace_log_from_api(response.json(), return_children)
408409

409-
def get_trace_log_scores(self, trace_id: str, check_context: bool = True) -> List[EvaluationResult]:
410+
def get_trace_log_scores(self, trace_id: str, check_context: bool = True, return_children: bool = True) -> List[EvaluationResult]:
410411
"""
411412
Get the scores from the trace log. If the scores are not present in the trace log, fetch them from the DB.
412413
Args:
413414
trace_id: The trace id to get the scores for.
414415
check_context: If True, will check the context for the scores first before fetching from the DB.
416+
return_children: If True, will return the children logs in the tree structure.
415417
416418
Returns: A list of EvaluationResult objects.
417419
"""
418420
# try to get trace_id scores from context
419421
if check_context:
420422
if scores := (trace_data.get()[trace_id].scores or []):
421-
print("Scores from context", scores)
422423
return scores
423424

424425
response = self._client.request("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
425-
tree: TraceLogTree = structure_trace_log_from_api(response.json())
426+
tree = structure_trace_log_from_api(response.json(), return_children)
426427
return extract_scores(tree)
427428

428-
async def aget_trace_log(self, trace_id: str) -> TraceLogTree:
429+
async def aget_trace_log(self, trace_id: str, return_children: bool = True) -> Union[TraceLogTree, TraceLog]:
429430
response = await self._client.request_async("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
430-
return structure_trace_log_from_api(response.json())
431+
return structure_trace_log_from_api(response.json(), return_children)
431432

432433
def list_experiments(self, filter_conditions: Optional[ListExperimentUUIDsFilters] = ListExperimentUUIDsFilters()) -> List[ExperimentWithPinnedStatsSchema]:
433434
response = self._client.request("POST", LIST_EXPERIMENTS_ENDPOINT, data=asdict(filter_conditions))
@@ -437,13 +438,15 @@ async def alist_experiments(self, filter_conditions: Optional[ListExperimentUUID
437438
response = await self._client.request_async("POST", LIST_EXPERIMENTS_ENDPOINT, data=asdict(filter_conditions))
438439
return structure(response.json(), List[ExperimentWithPinnedStatsSchema])
439440

440-
def get_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters()) -> List[TraceLogTree]:
441+
def get_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters(), return_children: bool = False) -> List[Union[TraceLogTree, TraceLog]]:
441442
response = self._client.request("POST", GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment_uuid), data=asdict(filters))
442-
return structure_trace_logs_from_api(response.json())
443+
return structure_trace_logs_from_api(response.json(), return_children)
443444

444-
async def aget_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters()) -> List[TraceLogTree]:
445+
async def aget_experiment_trace_logs(
446+
self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters(), return_children: bool = False
447+
) -> List[Union[TraceLogTree, TraceLog]]:
445448
response = await self._client.request_async("POST", GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment_uuid), data=asdict(filters))
446-
return structure_trace_logs_from_api(response.json())
449+
return structure_trace_logs_from_api(response.json(), return_children)
447450

448451
def get_experiment(self, experiment_uuid: str) -> Optional[ExperimentWithPinnedStatsSchema]:
449452
filter_conditions = ListExperimentUUIDsFilters(experiment_uuids=[experiment_uuid])
@@ -472,14 +475,15 @@ def new_init(self, *args, **kwargs):
472475
return subclass
473476

474477

475-
def extract_scores(tree: TraceLogTree) -> List[EvaluationResult]:
478+
def extract_scores(tree: Union[TraceLogTree, TraceLog]) -> List[EvaluationResult]:
476479
scores: List[EvaluationResult] = []
477480

478-
def traverse(node: TraceLogTree):
481+
def traverse(node: Union[TraceLogTree, TraceLog]):
479482
if node.scores:
480483
scores.extend(node.scores or [])
481-
for child in node.children_logs:
482-
traverse(child)
484+
if isinstance(node, TraceLogTree):
485+
for child in node.children_logs or []:
486+
traverse(child)
483487

484488
traverse(tree)
485489
return scores

parea/helpers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def timezone_aware_now() -> datetime:
8181
return datetime.now(pytz.utc)
8282

8383

84-
def structure_trace_log_from_api(d: dict) -> TraceLogTree:
84+
def structure_trace_log_from_api(d: dict, return_children: bool = False) -> Union[TraceLogTree, TraceLog]:
8585
def structure_union_type(obj: Any, cl: type) -> Any:
8686
if isinstance(obj, str):
8787
return obj
@@ -92,11 +92,13 @@ def structure_union_type(obj: Any, cl: type) -> Any:
9292

9393
converter = GenConverter()
9494
converter.register_structure_hook(Union[str, Dict[str, str], None], structure_union_type)
95-
return converter.structure(d, TraceLogTree)
95+
if return_children:
96+
return converter.structure(d, TraceLogTree)
97+
return converter.structure(d, TraceLog)
9698

9799

98-
def structure_trace_logs_from_api(data: List[dict]) -> List[TraceLogTree]:
99-
return [structure_trace_log_from_api(d) for d in data]
100+
def structure_trace_logs_from_api(data: List[dict], return_children: bool = False) -> List[Union[TraceLogTree, TraceLog]]:
101+
return [structure_trace_log_from_api(d, return_children) for d in data]
100102

101103

102104
PAREA_LOGGING_DISABLED = False

0 commit comments

Comments
 (0)