3232 ListExperimentUUIDsFilters ,
3333 ProjectSchema ,
3434 TestCaseCollection ,
35- TraceLog ,
3635 TraceLogFilters ,
3736 TraceLogTree ,
3837 UseDeployedPrompt ,
@@ -403,32 +402,32 @@ def _update_data_and_trace(self, data: Completion) -> Completion:
403402
404403 return data
405404
406- def get_trace_log (self , trace_id : str , return_children : bool = True ) -> Union [ TraceLogTree , TraceLog ] :
405+ def get_trace_log (self , trace_id : str ) -> TraceLogTree :
407406 response = self ._client .request ("GET" , GET_TRACE_LOG_ENDPOINT .format (trace_id = trace_id ))
408- return structure_trace_log_from_api (response .json (), return_children )
407+ return structure_trace_log_from_api (response .json ())
409408
410- def get_trace_log_scores (self , trace_id : str , check_context : bool = True , return_children : bool = True ) -> List [EvaluationResult ]:
409+ def get_trace_log_scores (self , trace_id : str , check_context : bool = True ) -> List [EvaluationResult ]:
411410 """
412411 Get the scores from the trace log. If the scores are not present in the trace log, fetch them from the DB.
413412 Args:
414413 trace_id: The trace id to get the scores for.
415414 check_context: If True, will check the context for the scores first before fetching from the DB.
416- return_children: If True, will return the children logs in the tree structure.
417415
418416 Returns: A list of EvaluationResult objects.
419417 """
420418 # try to get trace_id scores from context
421419 if check_context :
422420 if scores := (trace_data .get ()[trace_id ].scores or []):
421+ print ("Scores from context" , scores )
423422 return scores
424423
425424 response = self ._client .request ("GET" , GET_TRACE_LOG_ENDPOINT .format (trace_id = trace_id ))
426- tree = structure_trace_log_from_api (response .json (), return_children )
425+ tree : TraceLogTree = structure_trace_log_from_api (response .json ())
427426 return extract_scores (tree )
428427
429- async def aget_trace_log (self , trace_id : str , return_children : bool = True ) -> Union [ TraceLogTree , TraceLog ] :
428+ async def aget_trace_log (self , trace_id : str ) -> TraceLogTree :
430429 response = await self ._client .request_async ("GET" , GET_TRACE_LOG_ENDPOINT .format (trace_id = trace_id ))
431- return structure_trace_log_from_api (response .json (), return_children )
430+ return structure_trace_log_from_api (response .json ())
432431
433432 def list_experiments (self , filter_conditions : Optional [ListExperimentUUIDsFilters ] = ListExperimentUUIDsFilters ()) -> List [ExperimentWithPinnedStatsSchema ]:
434433 response = self ._client .request ("POST" , LIST_EXPERIMENTS_ENDPOINT , data = asdict (filter_conditions ))
@@ -438,15 +437,13 @@ async def alist_experiments(self, filter_conditions: Optional[ListExperimentUUID
438437 response = await self ._client .request_async ("POST" , LIST_EXPERIMENTS_ENDPOINT , data = asdict (filter_conditions ))
439438 return structure (response .json (), List [ExperimentWithPinnedStatsSchema ])
440439
441- def get_experiment_trace_logs (self , experiment_uuid : str , filters : TraceLogFilters = TraceLogFilters (), return_children : bool = False ) -> List [Union [ TraceLogTree , TraceLog ] ]:
440+ def get_experiment_trace_logs (self , experiment_uuid : str , filters : TraceLogFilters = TraceLogFilters ()) -> List [TraceLogTree ]:
442441 response = self ._client .request ("POST" , GET_EXPERIMENT_LOGS_ENDPOINT .format (experiment_uuid = experiment_uuid ), data = asdict (filters ))
443- return structure_trace_logs_from_api (response .json (), return_children )
442+ return structure_trace_logs_from_api (response .json ())
444443
445- async def aget_experiment_trace_logs (
446- self , experiment_uuid : str , filters : TraceLogFilters = TraceLogFilters (), return_children : bool = False
447- ) -> List [Union [TraceLogTree , TraceLog ]]:
444+ async def aget_experiment_trace_logs (self , experiment_uuid : str , filters : TraceLogFilters = TraceLogFilters ()) -> List [TraceLogTree ]:
448445 response = await self ._client .request_async ("POST" , GET_EXPERIMENT_LOGS_ENDPOINT .format (experiment_uuid = experiment_uuid ), data = asdict (filters ))
449- return structure_trace_logs_from_api (response .json (), return_children )
446+ return structure_trace_logs_from_api (response .json ())
450447
451448 def get_experiment (self , experiment_uuid : str ) -> Optional [ExperimentWithPinnedStatsSchema ]:
452449 filter_conditions = ListExperimentUUIDsFilters (experiment_uuids = [experiment_uuid ])
@@ -475,15 +472,14 @@ def new_init(self, *args, **kwargs):
475472 return subclass
476473
477474
478- def extract_scores (tree : Union [ TraceLogTree , TraceLog ] ) -> List [EvaluationResult ]:
475+ def extract_scores (tree : TraceLogTree ) -> List [EvaluationResult ]:
479476 scores : List [EvaluationResult ] = []
480477
481- def traverse (node : Union [ TraceLogTree , TraceLog ] ):
478+ def traverse (node : TraceLogTree ):
482479 if node .scores :
483480 scores .extend (node .scores or [])
484- if isinstance (node , TraceLogTree ):
485- for child in node .children_logs or []:
486- traverse (child )
481+ for child in node .children_logs :
482+ traverse (child )
487483
488484 traverse (tree )
489485 return scores
0 commit comments