3232 ListExperimentUUIDsFilters ,
3333 ProjectSchema ,
3434 TestCaseCollection ,
35+ TraceLog ,
3536 TraceLogFilters ,
3637 TraceLogTree ,
3738 UseDeployedPrompt ,
@@ -402,32 +403,32 @@ def _update_data_and_trace(self, data: Completion) -> Completion:
402403
403404 return data
404405
405- def get_trace_log (self , trace_id : str ) -> TraceLogTree :
406+ def get_trace_log (self , trace_id : str , return_children : bool = True ) -> Union [ TraceLogTree , TraceLog ] :
406407 response = self ._client .request ("GET" , GET_TRACE_LOG_ENDPOINT .format (trace_id = trace_id ))
407- return structure_trace_log_from_api (response .json ())
408+ return structure_trace_log_from_api (response .json (), return_children )
408409
409- def get_trace_log_scores (self , trace_id : str , check_context : bool = True ) -> List [EvaluationResult ]:
410+ def get_trace_log_scores (self , trace_id : str , check_context : bool = True , return_children : bool = True ) -> List [EvaluationResult ]:
410411 """
411412 Get the scores from the trace log. If the scores are not present in the trace log, fetch them from the DB.
412413 Args:
413414 trace_id: The trace id to get the scores for.
414415 check_context: If True, will check the context for the scores first before fetching from the DB.
416+ return_children: If True, will return the children logs in the tree structure.
415417
416418 Returns: A list of EvaluationResult objects.
417419 """
418420 # try to get trace_id scores from context
419421 if check_context :
420422 if scores := (trace_data .get ()[trace_id ].scores or []):
421- print ("Scores from context" , scores )
422423 return scores
423424
424425 response = self ._client .request ("GET" , GET_TRACE_LOG_ENDPOINT .format (trace_id = trace_id ))
425- tree : TraceLogTree = structure_trace_log_from_api (response .json ())
426+ tree = structure_trace_log_from_api (response .json (), return_children )
426427 return extract_scores (tree )
427428
428- async def aget_trace_log (self , trace_id : str ) -> TraceLogTree :
429+ async def aget_trace_log (self , trace_id : str , return_children : bool = True ) -> Union [ TraceLogTree , TraceLog ] :
429430 response = await self ._client .request_async ("GET" , GET_TRACE_LOG_ENDPOINT .format (trace_id = trace_id ))
430- return structure_trace_log_from_api (response .json ())
431+ return structure_trace_log_from_api (response .json (), return_children )
431432
432433 def list_experiments (self , filter_conditions : Optional [ListExperimentUUIDsFilters ] = ListExperimentUUIDsFilters ()) -> List [ExperimentWithPinnedStatsSchema ]:
433434 response = self ._client .request ("POST" , LIST_EXPERIMENTS_ENDPOINT , data = asdict (filter_conditions ))
@@ -437,13 +438,15 @@ async def alist_experiments(self, filter_conditions: Optional[ListExperimentUUID
437438 response = await self ._client .request_async ("POST" , LIST_EXPERIMENTS_ENDPOINT , data = asdict (filter_conditions ))
438439 return structure (response .json (), List [ExperimentWithPinnedStatsSchema ])
439440
440- def get_experiment_trace_logs (self , experiment_uuid : str , filters : TraceLogFilters = TraceLogFilters ()) -> List [TraceLogTree ]:
441+ def get_experiment_trace_logs (self , experiment_uuid : str , filters : TraceLogFilters = TraceLogFilters (), return_children : bool = False ) -> List [Union [ TraceLogTree , TraceLog ] ]:
441442 response = self ._client .request ("POST" , GET_EXPERIMENT_LOGS_ENDPOINT .format (experiment_uuid = experiment_uuid ), data = asdict (filters ))
442- return structure_trace_logs_from_api (response .json ())
443+ return structure_trace_logs_from_api (response .json (), return_children )
443444
444- async def aget_experiment_trace_logs (self , experiment_uuid : str , filters : TraceLogFilters = TraceLogFilters ()) -> List [TraceLogTree ]:
445+ async def aget_experiment_trace_logs (
446+ self , experiment_uuid : str , filters : TraceLogFilters = TraceLogFilters (), return_children : bool = False
447+ ) -> List [Union [TraceLogTree , TraceLog ]]:
445448 response = await self ._client .request_async ("POST" , GET_EXPERIMENT_LOGS_ENDPOINT .format (experiment_uuid = experiment_uuid ), data = asdict (filters ))
446- return structure_trace_logs_from_api (response .json ())
449+ return structure_trace_logs_from_api (response .json (), return_children )
447450
448451 def get_experiment (self , experiment_uuid : str ) -> Optional [ExperimentWithPinnedStatsSchema ]:
449452 filter_conditions = ListExperimentUUIDsFilters (experiment_uuids = [experiment_uuid ])
@@ -472,14 +475,15 @@ def new_init(self, *args, **kwargs):
472475 return subclass
473476
474477
475- def extract_scores (tree : TraceLogTree ) -> List [EvaluationResult ]:
478+ def extract_scores (tree : Union [ TraceLogTree , TraceLog ] ) -> List [EvaluationResult ]:
476479 scores : List [EvaluationResult ] = []
477480
478- def traverse (node : TraceLogTree ):
481+ def traverse (node : Union [ TraceLogTree , TraceLog ] ):
479482 if node .scores :
480483 scores .extend (node .scores or [])
481- for child in node .children_logs :
482- traverse (child )
484+ if isinstance (node , TraceLogTree ):
485+ for child in node .children_logs or []:
486+ traverse (child )
483487
484488 traverse (tree )
485489 return scores
0 commit comments