@@ -201,17 +201,6 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
201201 return reduced_metrics
202202
203203
204- def record_episode_sample (table_name : str , episode ):
205- """
206- Record a structured sample-level log for a single episode.
207- Args:
208- table_name (str): logging prefix (e.g. "rollout/sample").
209- episode (Episode): episode object with filled attributes.
210- """
211- sample = episode .to_dict (exclude = ["ref_logprobs" , "completion" ])
212- record_metric (table_name , sample , Reduce .SAMPLE )
213-
214-
215204################
216205# Accumulators #
217206################
@@ -430,7 +419,7 @@ class SampleAccumulator(MetricAccumulator):
430419 """
431420
432421 def __init__ (
433- self , reduction : Reduce , top_k : int = 1 , bottom_k : int = 1 , key : str = "reward "
422+ self , reduction : Reduce , top_k : int = 1 , bottom_k : int = 1 , key : str = "score "
434423 ):
435424 super ().__init__ (reduction )
436425 self .samples : List [Dict [str , Any ]] = []
@@ -869,12 +858,10 @@ async def finish(self) -> None:
869858 async def log_samples (self , samples : List [Metric ], step : int ) -> None :
870859 """Pretty-print sample-level logs to console."""
871860
872- logger .info (f"========== SAMPLE LOGS STEP { step } ==========" )
873861 for sample in samples :
874862 table_name , table_rows = sample .key , sample .value
875863 logger .info (f"[{ table_name } ] ({ len (table_rows )} samples)" )
876864 logger .info (json .dumps (table_rows , indent = 2 , ensure_ascii = False ))
877- logger .info ("==============================================\n " )
878865
879866
880867class WandbBackend (LoggerBackend ):
@@ -1056,6 +1043,13 @@ async def log_samples(self, samples: List[Metric], step: int) -> None:
10561043
10571044 # Add rows (fill missing columns with None)
10581045 for s in table_rows :
1046+ # Check for extra columns not in the table schema
1047+ extra_columns = set (s .keys ()) - set (table .columns )
1048+ if extra_columns :
1049+ logger .warning (
1050+ f"WandbBackend: Row has extra columns not in table '{ table_name } ': { sorted (extra_columns )} . "
1051+ f"These will be ignored."
1052+ )
10591053 values = [s .get (c ) for c in table .columns ]
10601054 table .add_data (* values )
10611055
0 commit comments