@@ -139,12 +139,32 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metric]:
         list[Metric]: List of reduced metrics

     Example:
-        states = [
-            {"loss": {"count": 5, "sum": 14, "reduction_type": Reduce.MEAN}},
-            {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}},
-        ]
-        reduce_metrics_states(states)
-        >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]
+        >>> states = [
+        ...     {
+        ...         "loss": {"count": 5, "sum": 14, "reduction_type": "mean"},
+        ...         "reward/sample": {
+        ...             "reduction_type": "sample",
+        ...             "samples": [{"episode_id": 1, "reward": 0.5}],
+        ...         },
+        ...     },
+        ...     {
+        ...         "loss": {"count": 10, "sum": 16, "reduction_type": "mean"},
+        ...         "reward/sample": {
+        ...             "reduction_type": "sample",
+        ...             "samples": [{"episode_id": 2, "reward": 1.0}],
+        ...         },
+        ...     },
+        ... ]
+        >>> metrics = reduce_metrics_states(states)
+        >>> for m in metrics:
+        ...     print(m)
+        Metric(key='loss', value=2.0, reduction=Reduce.MEAN)
+        Metric(
+            key='reward/sample',
+            value=[{'episode_id': 1, 'reward': 0.5},
+                   {'episode_id': 2, 'reward': 1.0}],
+            reduction=Reduce.SAMPLE,
+        )

     Raises:
         ValueError: on mismatched reduction types for the same metric key.
@@ -186,6 +206,31 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metric]:
     return reduced_metrics


+def record_episode_sample(table_name: str, episode):
+    """
+    Record a structured sample-level log for a single episode.
+
+    Args:
+        table_name (str): logging prefix (e.g. "rollout/sample").
+        episode (Episode): episode object with filled attributes.
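+
+    Example (illustrative, assuming a populated ``episode``)::
+
+        record_episode_sample("rollout/sample", episode)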
215+ """
216+ sample = {
217+ "episode_id" : episode .episode_id ,
218+ "policy_version" : episode .policy_version ,
219+ "prompt" : episode .request ,
220+ "response" : episode .response ,
221+ "target" : str (episode .target ),
222+ ** (
223+ episode .reward_breakdown or {}
224+ ), # per-fn breakdown including the average reward
225+ "advantage" : episode .advantage ,
226+ "request_len" : episode .request_len ,
227+ "response_len" : episode .response_len ,
228+ "pad_id" : episode .pad_id ,
229+ }
230+
231+ record_metric (table_name , sample , Reduce .SAMPLE )
232+
233+
 #################
 # SampleFilters #
 #################
@@ -656,7 +701,12 @@ def push(self, metric: Metric) -> None:

         # For PER_RANK_NO_REDUCE backends: stream without reduce
         for backend in self.per_rank_no_reduce_backends:
-            backend.log_stream(metric=metric, global_step=self.global_step)
+            if metric.reduction == Reduce.SAMPLE:
+                # Wrap the singleton Metric into the expected {key: [list_of_dicts]} format
+                sample = {metric.key: [metric.value]}
+                asyncio.create_task(backend.log_samples(sample, self.global_step))
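+                # Fire-and-forget: push() is sync while log_samples is async,
+                # so the coroutine is scheduled here; this assumes a running
+                # event loop and does not surface log_samples failures.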
+            else:
+                backend.log_stream(metric=metric, global_step=self.global_step)

         # Always accumulate for reduction and state return
         key = metric.key
@@ -711,8 +761,21 @@ async def flush(

         if self.per_rank_reduce_backends:
             metrics_for_backends = reduce_metrics_states([states])
+            # Split into scalar metrics and sample metrics
+            scalar_metrics = [
+                m for m in metrics_for_backends if m.reduction != Reduce.SAMPLE
+            ]
+            sample_metrics = {
+                m.key: m.value
+                for m in metrics_for_backends
+                if m.reduction == Reduce.SAMPLE
+            }
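+            # (illustrative) sample_metrics maps table name -> row dicts, e.g.
+            # {"rollout/sample": [{"episode_id": 1, "reward": 0.5}, ...]}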
+
             for backend in self.per_rank_reduce_backends:
-                await backend.log_batch(metrics_for_backends, global_step)
+                if scalar_metrics:
+                    await backend.log_batch(scalar_metrics, global_step)
+                if sample_metrics:
+                    await backend.log_samples(sample_metrics, global_step)

         # Update step counter for streaming backends
         # Note: This is incremented AFTER flush completes, so metrics recorded between
@@ -846,6 +909,16 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
     async def finish(self) -> None:
         pass

+    async def log_samples(self, samples: dict[str, list[dict]], step: int) -> None:
+        """Pretty-print sample-level logs to the console."""
+        import json
+
+        logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
+        for table_name, table_rows in samples.items():
+            logger.info(f"[{table_name}] ({len(table_rows)} samples)")
+            logger.info(json.dumps(table_rows, indent=2, ensure_ascii=False))
+        logger.info("==============================================\n")
+


 class WandbBackend(LoggerBackend):
     """
@@ -882,6 +955,7 @@ def __init__(
         )
         self.run = None
         self.process_name = None
+        self._tables: dict[str, "wandb.Table"] = {}

     async def init(
         self,
@@ -992,13 +1066,58 @@ def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
         # note: we don't use step here, since wandb keeps only the latest value for each step
         self.run.log(log_data)

+    async def log_samples(self, samples: dict[str, list[dict]], step: int) -> None:
+        """Log sample-level data incrementally to persistent WandB Tables."""
+        import wandb
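+        # NOTE: Table(log_mode="INCREMENTAL") requires a recent wandb client;
+        # older releases only support the default immutable tables.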
+
+        if not self.run:
+            return
+
+        for table_name, table_rows in samples.items():
+            if not table_rows:
+                continue
+
+            # If the table doesn't exist yet, create it in INCREMENTAL mode
+            if table_name not in self._tables:
+                columns = list(table_rows[0].keys())
+                table = wandb.Table(columns=columns, log_mode="INCREMENTAL")
+                self._tables[table_name] = table
+                logger.info(
+                    f"WandbBackend: Created new incremental table: {table_name}"
+                )
+            else:
+                table = self._tables[table_name]
+
+            # Add rows (fill missing columns with None)
+            for s in table_rows:
+                values = [s.get(c) for c in table.columns]
+                table.add_data(*values)
+
+            # Log the same table object (INCREMENTAL update)
+            self.run.log({f"{table_name}_table": table})
+            logger.info(
+                f"WandbBackend: Appended {len(table_rows)} rows to incremental table '{table_name}' at step {step}"
+            )
+
     def get_metadata_for_secondary_ranks(self) -> dict[str, Any]:
         if self.run and self.per_rank_share_run:
             return {"shared_run_id": self.run.id}
         return {}

     async def finish(self) -> None:
+        import wandb
+
         if self.run:
+            # Convert each incremental table to an immutable copy before finishing
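+            # (Re-logging as IMMUTABLE is intended to make the complete table
+            # render in the UI once the run has ended.)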
+            for table_name, incr_table in self._tables.items():
+                final_table = wandb.Table(
+                    columns=incr_table.columns,
+                    data=incr_table.data,
+                    log_mode="IMMUTABLE",
+                )
+                self.run.log({table_name: final_table})
+                logger.info(f"WandbBackend: Finalized table {table_name}")
+
             self.run.finish()
             logger.info(f"WandbBackend {self.process_name}: Finished run")
