@@ -113,6 +113,38 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None
113113 collector .push (key , value , reduction )
114114
115115
def record_episode_sample(key: str, episode):
    """
    Emit one structured, sample-level record for a single episode.

    The episode's attributes are flattened into a plain dict, which is then
    handed to ``record_metric`` with ``Reduce.SAMPLE``.

    Args:
        key (str): logging prefix (e.g. "rollout/sample").
        episode (Episode): episode object with filled attributes. Its
            ``reward_breakdown`` (when truthy) is merged into the record as
            top-level entries, and ``ref_logprobs`` is collapsed to its mean
            (``None`` when the attribute is ``None``).
    """
    # Identification and text fields lead the record.
    leading_fields = {
        "episode_id": episode.episode_id,
        "policy_version": episode.policy_version,
        "prompt": episode.request,
        "response": episode.response,
        "target": episode.target,
    }

    # Per-fn reward breakdown (per the original author's note it includes the
    # average reward). Merged here, its keys can override the leading fields
    # above and can themselves be overridden by the fields assigned below.
    sample = {**leading_fields, **(episode.reward_breakdown or {})}

    sample["advantage"] = episode.advantage

    ref_logprobs = episode.ref_logprobs
    if ref_logprobs is None:
        sample["ref_logprobs"] = None
    else:
        sample["ref_logprobs"] = ref_logprobs.mean().item()

    sample["request_len"] = episode.request_len
    sample["response_len"] = episode.response_len
    sample["pad_id"] = episode.pad_id

    record_metric(key, sample, Reduce.SAMPLE)
146+
147+
116148def reduce_metrics_states (states : List [Dict [str , Dict [str , Any ]]]) -> Dict [str , Any ]:
117149 """Reduce metric accumulators states to a single value per metric.
118150
@@ -465,7 +497,9 @@ class SampleAccumulator(MetricAccumulator):
465497 Optionally uses a SampleFilter to decide what to keep at append/flush time.
466498 """
467499
def __init__(self, reduction: Reduce, filter: SampleFilter | None = ...):
    """Set up an empty sample buffer and the filter applied to it.

    Args:
        reduction: reduction kind, forwarded to the parent initializer.
        filter: filter used to decide which samples to keep. When omitted,
            a new ``TopBottomKFilter()`` is created for this accumulator.
            Pass ``None`` explicitly to store no filter.
    """
    super().__init__(reduction)
    self.samples: List[Dict[str, Any]] = []
    # A default written as ``TopBottomKFilter()`` in the signature would be
    # evaluated once at definition time and shared by every accumulator.
    # NOTE(review): assumes the filter keeps per-key state between calls --
    # confirm against TopBottomKFilter. ``...`` serves as the "not given"
    # marker because ``None`` is already a meaningful value here.
    self.filter = TopBottomKFilter() if filter is ... else filter
@@ -598,6 +632,7 @@ def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None:
598632 raise ValueError ("Collector not initialized—call init first" )
599633
600634 if key not in self .accumulators :
635+ # TODO: make sample filter configurable
601636 self .accumulators [key ] = reduction .accumulator_class (reduction )
602637
603638 self .accumulators [key ].append (value )
@@ -724,6 +759,20 @@ async def log(self, metrics: Dict[str, Any], step: int) -> None:
724759 logger .info (f" { key } : { value } " )
725760 logger .info ("==============================\n " )
726761
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
    """Write sample-level logs to the console in pretty-printed form.

    Args:
        samples: mapping from a log key to the list of sample dicts
            recorded under that key. Nothing is logged when it is empty.
        step: step number shown in the banner line.
    """
    if samples:
        from pprint import pformat

        logger.info(f"=== [{self.prefix}] - SAMPLE LOGS STEP {step} ===")
        for name, entries in samples.items():
            logger.info(f"[{name}] ({len(entries)} samples)")
            for entry in entries:
                # One log call per sample, each formatted on its own.
                logger.info(pformat(entry, indent=4, width=120, compact=True))
        logger.info("==============================================\n")
775+
async def finish(self) -> None:
    """Do nothing; this method performs no work when awaited."""
    return None
729778
@@ -836,6 +885,28 @@ async def log(self, metrics: Dict[str, Any], step: int) -> None:
836885 else :
837886 logger .debug (f"WandbBackend: No run started, skipping log for { self .name } " )
838887
async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
    """Log sample-level data to WandB Tables.

    For every key with at least one sample, a table is built and logged
    under ``"<key>_table"`` together with ``"global_step"``.

    Args:
        samples: mapping from a log key to the list of sample dicts
            recorded under that key.
        step: value logged as ``"global_step"`` alongside each table.
    """
    if not self.run or not samples:
        return

    # Imported only once there is something to log, so a no-op call does
    # not require the package to be importable.
    import wandb

    for key, rows in samples.items():
        if not rows:
            continue

        # Columns are the union of keys over all rows, in first-seen order.
        # Samples need not share a schema (e.g. differing reward breakdowns),
        # and taking only the first row's keys would drop later fields.
        columns = list(dict.fromkeys(col for row in rows for col in row))
        table = wandb.Table(columns=columns)
        for sample in rows:
            # Fields missing from a given sample are filled with None.
            table.add_data(*(sample.get(col) for col in columns))

        self.run.log({f"{key}_table": table, "global_step": step})
        logger.info(
            f"WandbBackend: Logged {len(rows)} samples for {key} at step {step}"
        )
909+
839910 def get_metadata_for_secondary_ranks (self ) -> Dict [str , Any ]:
840911 if self .run and not self .reduce_across_ranks and self .share_run_id :
841912 return {"shared_run_id" : self .run .id }
0 commit comments