@@ -195,7 +195,7 @@ def record_episode_sample(table_name: str, episode):
195195 table_name (str): logging prefix (e.g. "rollout/sample").
196196 episode (Episode): episode object with filled attributes.
197197 """
198- sample = episode .to_dict ()
198+ sample = episode .to_dict (exclude = [ "ref_logprobs" , "completion" ] )
199199 record_metric (table_name , sample , Reduce .SAMPLE )
200200
201201
@@ -675,9 +675,7 @@ def push(self, metric: Metric) -> None:
675675 for backend in self .per_rank_no_reduce_backends :
676676
677677 if metric .reduction == Reduce .SAMPLE :
678- # Wrap singleton Metric into expected {key: [list_of_dicts]} format
679- sample = {metric .key : [metric .value ]}
680- asyncio .create_task (backend .log_samples (sample , self .global_step ))
678+ asyncio .create_task (backend .log_samples ([metric ], self .global_step ))
681679 else :
682680 backend .log_stream (metric = metric , global_step = self .global_step )
683681
@@ -882,11 +880,12 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
882880 async def finish (self ) -> None :
883881 pass
884882
885- async def log_samples (self , samples : Dict [ str , List [dict ] ], step : int ) -> None :
883+ async def log_samples (self , samples : List [Metric ], step : int ) -> None :
886884 """Pretty-print sample-level logs to console."""
887885
888886 logger .info (f"========== SAMPLE LOGS STEP { step } ==========" )
889- for table_name , table_rows in samples .items ():
887+ for sample in samples :
888+ table_name , table_rows = sample .key , sample .value
890889 logger .info (f"[{ table_name } ] ({ len (table_rows )} samples)" )
891890 logger .info (json .dumps (table_rows , indent = 2 , ensure_ascii = False ))
892891 logger .info ("==============================================\n " )
@@ -1038,14 +1037,15 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
10381037 # note: here we dont use step since wandb keeps only the latest value for each step
10391038 self .run .log (log_data )
10401039
1041- async def log_samples (self , samples : Dict [ str , List [dict ] ], step : int ) -> None :
1040+ async def log_samples (self , samples : List [Metric ], step : int ) -> None :
10421041 """Log sample-level data incrementally to persistent WandB Tables."""
10431042 import wandb
10441043
10451044 if not self .run :
10461045 return
10471046
1048- for table_name , table_rows in samples .items ():
1047+ for sample in samples :
1048+ table_name , table_rows = sample .key , sample .value
10491049 if not table_rows :
10501050 continue
10511051
0 commit comments