@@ -143,11 +143,6 @@ def record_episode_sample(key: str, episode):
143143 episode .reward_breakdown or {}
144144 ), # per-fn breakdown including the average reward
145145 "advantage" : episode .advantage ,
146- "ref_logprobs" : float (
147- episode .ref_logprobs .mean ().item ()
148- if episode .ref_logprobs is not None
149- else None
150- ),
151146 "request_len" : episode .request_len ,
152147 "response_len" : episode .response_len ,
153148 "pad_id" : episode .pad_id ,
@@ -850,16 +845,12 @@ def log_stream(self, metric: Metric, step: int, *args, **kwargs) -> None:
850845
851846 async def log_samples (self , samples : Dict [str , List [dict ]], step : int ) -> None :
852847 """Pretty-print sample-level logs to console."""
853- if not samples :
854- return
855848 import json
856849
857850 logger .info (f"========== SAMPLE LOGS STEP { step } ==========" )
858- for key , rows in samples .items ():
859- logger .info (f"[{ key } ] ({ len (rows )} samples)" )
860- for sample in rows :
861- pretty = json .dumps (sample , indent = 2 , ensure_ascii = False )
862- logger .info (pretty )
851+ for table_name , table_rows in samples .items ():
852+ logger .info (f"[{ table_name } ] ({ len (table_rows )} samples)" )
853+ logger .info (json .dumps (table_rows , indent = 2 , ensure_ascii = False ))
863854 logger .info ("==============================================\n " )
864855
865856 async def finish (self ) -> None :
@@ -999,24 +990,24 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
999990 if not self .run or not samples :
1000991 return
1001992
1002- for key , rows in samples .items ():
1003- if not rows :
993+ for table_name , table_rows in samples .items ():
994+ if not table_rows :
1004995 continue
1005996
1006997 # Use all keys to avoid dropped fields
1007- columns = sorted ({k for s in rows for k in s .keys ()})
998+ columns = sorted ({k for s in table_rows for k in s .keys ()})
1008999 table = wandb .Table (columns = columns )
10091000
1010- for s in rows :
1011- values = [s .get (c ) for c in columns ]
1001+ for s in table_rows :
1002+ values = [s .get (c ) for c in columns ] # returns None for missing keys
10121003 table .add_data (* values )
10131004
10141005 # Unique table name avoids overwrite; commit forces sync
1015- table_name = f"{ key } _table_step{ step } "
1016- self .run .log ({table_name : table , "_num_rows" : len (rows )}, commit = True )
1006+ table_name = f"{ table_name } _table_step{ step } "
1007+ self .run .log ({table_name : table , "_num_rows" : len (table_rows )}, commit = True )
10171008
10181009 logger .info (
1019- f"WandbBackend: Logged { len (rows )} samples for { key } at step { step } "
1010+ f"WandbBackend: Logged { len (table_rows )} samples for { table_name } at step { step } "
10201011 )
10211012
10221013 def get_metadata_for_secondary_ranks (self ) -> Dict [str , Any ]:
0 commit comments