@@ -195,7 +195,7 @@ def record_episode_sample(table_name: str, episode):
195195 table_name (str): logging prefix (e.g. "rollout/sample").
196196 episode (Episode): episode object with filled attributes.
197197 """
198- sample = episode .to_dict ()
198+ sample = episode .to_dict (exclude = [ "ref_logprobs" , "completion" ] )
199199 record_metric (table_name , sample , Reduce .SAMPLE )
200200
201201
@@ -675,9 +675,7 @@ def push(self, metric: Metric) -> None:
675675 for backend in self .per_rank_no_reduce_backends :
676676
677677 if metric .reduction == Reduce .SAMPLE :
678- # Wrap singleton Metric into expected {key: [list_of_dicts]} format
679- sample = {metric .key : [metric .value ]}
680- asyncio .create_task (backend .log_samples (sample , self .global_step ))
678+ asyncio .create_task (backend .log_samples ([metric ], self .global_step ))
681679 else :
682680 backend .log_stream (metric = metric , global_step = self .global_step )
683681
@@ -867,11 +865,12 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
867865 async def finish (self ) -> None :
868866 pass
869867
870- async def log_samples (self , samples : Dict [ str , List [dict ] ], step : int ) -> None :
868+ async def log_samples (self , samples : List [Metric ], step : int ) -> None :
871869 """Pretty-print sample-level logs to console."""
872870
873871 logger .info (f"========== SAMPLE LOGS STEP { step } ==========" )
874- for table_name , table_rows in samples .items ():
872+ for sample in samples :
873+ table_name , table_rows = sample .key , sample .value
875874 logger .info (f"[{ table_name } ] ({ len (table_rows )} samples)" )
876875 logger .info (json .dumps (table_rows , indent = 2 , ensure_ascii = False ))
877876 logger .info ("==============================================\n " )
@@ -1014,14 +1013,15 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
10141013 # note: here we dont use step since wandb keeps only the latest value for each step
10151014 self .run .log (log_data )
10161015
1017- async def log_samples (self , samples : Dict [ str , List [dict ] ], step : int ) -> None :
1016+ async def log_samples (self , samples : List [Metric ], step : int ) -> None :
10181017 """Log sample-level data incrementally to persistent WandB Tables."""
10191018 import wandb
10201019
10211020 if not self .run :
10221021 return
10231022
1024- for table_name , table_rows in samples .items ():
1023+ for sample in samples :
1024+ table_name , table_rows = sample .key , sample .value
10251025 if not table_rows :
10261026 continue
10271027
0 commit comments