Skip to content

Commit 3d617d1

Browse files
authored
Merge pull request #958 from parea-ai/revert-957-fix-nested-cattrs
Revert "Avoid nested cattrs check on TraceLogTree"
2 parents f808bd2 + 9c0cc4d commit 3d617d1

File tree

4 files changed

+1285
-1216
lines changed

4 files changed

+1285
-1216
lines changed

parea/client.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
ListExperimentUUIDsFilters,
3333
ProjectSchema,
3434
TestCaseCollection,
35-
TraceLog,
3635
TraceLogFilters,
3736
TraceLogTree,
3837
UseDeployedPrompt,
@@ -403,32 +402,32 @@ def _update_data_and_trace(self, data: Completion) -> Completion:
403402

404403
return data
405404

406-
def get_trace_log(self, trace_id: str, return_children: bool = True) -> Union[TraceLogTree, TraceLog]:
405+
def get_trace_log(self, trace_id: str) -> TraceLogTree:
407406
response = self._client.request("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
408-
return structure_trace_log_from_api(response.json(), return_children)
407+
return structure_trace_log_from_api(response.json())
409408

410-
def get_trace_log_scores(self, trace_id: str, check_context: bool = True, return_children: bool = True) -> List[EvaluationResult]:
409+
def get_trace_log_scores(self, trace_id: str, check_context: bool = True) -> List[EvaluationResult]:
411410
"""
412411
Get the scores from the trace log. If the scores are not present in the trace log, fetch them from the DB.
413412
Args:
414413
trace_id: The trace id to get the scores for.
415414
check_context: If True, will check the context for the scores first before fetching from the DB.
416-
return_children: If True, will return the children logs in the tree structure.
417415
418416
Returns: A list of EvaluationResult objects.
419417
"""
420418
# try to get trace_id scores from context
421419
if check_context:
422420
if scores := (trace_data.get()[trace_id].scores or []):
421+
print("Scores from context", scores)
423422
return scores
424423

425424
response = self._client.request("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
426-
tree = structure_trace_log_from_api(response.json(), return_children)
425+
tree: TraceLogTree = structure_trace_log_from_api(response.json())
427426
return extract_scores(tree)
428427

429-
async def aget_trace_log(self, trace_id: str, return_children: bool = True) -> Union[TraceLogTree, TraceLog]:
428+
async def aget_trace_log(self, trace_id: str) -> TraceLogTree:
430429
response = await self._client.request_async("GET", GET_TRACE_LOG_ENDPOINT.format(trace_id=trace_id))
431-
return structure_trace_log_from_api(response.json(), return_children)
430+
return structure_trace_log_from_api(response.json())
432431

433432
def list_experiments(self, filter_conditions: Optional[ListExperimentUUIDsFilters] = ListExperimentUUIDsFilters()) -> List[ExperimentWithPinnedStatsSchema]:
434433
response = self._client.request("POST", LIST_EXPERIMENTS_ENDPOINT, data=asdict(filter_conditions))
@@ -438,15 +437,13 @@ async def alist_experiments(self, filter_conditions: Optional[ListExperimentUUID
438437
response = await self._client.request_async("POST", LIST_EXPERIMENTS_ENDPOINT, data=asdict(filter_conditions))
439438
return structure(response.json(), List[ExperimentWithPinnedStatsSchema])
440439

441-
def get_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters(), return_children: bool = False) -> List[Union[TraceLogTree, TraceLog]]:
440+
def get_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters()) -> List[TraceLogTree]:
442441
response = self._client.request("POST", GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment_uuid), data=asdict(filters))
443-
return structure_trace_logs_from_api(response.json(), return_children)
442+
return structure_trace_logs_from_api(response.json())
444443

445-
async def aget_experiment_trace_logs(
446-
self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters(), return_children: bool = False
447-
) -> List[Union[TraceLogTree, TraceLog]]:
444+
async def aget_experiment_trace_logs(self, experiment_uuid: str, filters: TraceLogFilters = TraceLogFilters()) -> List[TraceLogTree]:
448445
response = await self._client.request_async("POST", GET_EXPERIMENT_LOGS_ENDPOINT.format(experiment_uuid=experiment_uuid), data=asdict(filters))
449-
return structure_trace_logs_from_api(response.json(), return_children)
446+
return structure_trace_logs_from_api(response.json())
450447

451448
def get_experiment(self, experiment_uuid: str) -> Optional[ExperimentWithPinnedStatsSchema]:
452449
filter_conditions = ListExperimentUUIDsFilters(experiment_uuids=[experiment_uuid])
@@ -475,15 +472,14 @@ def new_init(self, *args, **kwargs):
475472
return subclass
476473

477474

478-
def extract_scores(tree: Union[TraceLogTree, TraceLog]) -> List[EvaluationResult]:
475+
def extract_scores(tree: TraceLogTree) -> List[EvaluationResult]:
479476
scores: List[EvaluationResult] = []
480477

481-
def traverse(node: Union[TraceLogTree, TraceLog]):
478+
def traverse(node: TraceLogTree):
482479
if node.scores:
483480
scores.extend(node.scores or [])
484-
if isinstance(node, TraceLogTree):
485-
for child in node.children_logs or []:
486-
traverse(child)
481+
for child in node.children_logs:
482+
traverse(child)
487483

488484
traverse(tree)
489485
return scores

parea/helpers.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def timezone_aware_now() -> datetime:
8181
return datetime.now(pytz.utc)
8282

8383

84-
def structure_trace_log_from_api(d: dict, return_children: bool = False) -> Union[TraceLogTree, TraceLog]:
84+
def structure_trace_log_from_api(d: dict) -> TraceLogTree:
8585
def structure_union_type(obj: Any, cl: type) -> Any:
8686
if isinstance(obj, str):
8787
return obj
@@ -92,13 +92,11 @@ def structure_union_type(obj: Any, cl: type) -> Any:
9292

9393
converter = GenConverter()
9494
converter.register_structure_hook(Union[str, Dict[str, str], None], structure_union_type)
95-
if return_children:
96-
return converter.structure(d, TraceLogTree)
97-
return converter.structure(d, TraceLog)
95+
return converter.structure(d, TraceLogTree)
9896

9997

100-
def structure_trace_logs_from_api(data: List[dict], return_children: bool = False) -> List[Union[TraceLogTree, TraceLog]]:
101-
return [structure_trace_log_from_api(d, return_children) for d in data]
98+
def structure_trace_logs_from_api(data: List[dict]) -> List[TraceLogTree]:
99+
return [structure_trace_log_from_api(d) for d in data]
102100

103101

104102
PAREA_LOGGING_DISABLED = False

0 commit comments

Comments
 (0)