@@ -228,7 +228,15 @@ def record_episode_sample(table_name: str, episode):
         "pad_id": episode.pad_id,
     }

+    print(
+        "[DEBUG] Adding sample to table via record_metric, episode_id: ",
+        episode.episode_id,
+    )
     record_metric(table_name, sample, Reduce.SAMPLE)
+    print(
+        "[DEBUG] Added sample to table via record_metric, episode_id: ",
+        episode.episode_id,
+    )


 #################
@@ -499,19 +507,22 @@ def __init__(self, reduction: Reduce):
         super().__init__(reduction)
         self.samples: List[Dict[str, Any]] = []
         self.filter = TopBottomKFilter()
+        self.is_reset = True

     def append(self, value: dict) -> None:
         if not isinstance(value, dict):
             raise ValueError(f"Expected dict, got {type(value)}")

+        self.is_reset = False
         # Only keep the sample if filter_append returns True
         if self.filter.filter_append(value):
             self.samples.append(value)

     def get_value(self) -> list[dict]:
         """Return locally collected (and optionally filtered) samples."""
         # Apply flush-time filter (e.g. heap selection, threshold trimming)
-        return self.filter.filter_flush(self.samples)
+        results = self.filter.filter_flush(self.samples)
+        return results

     def get_state(self) -> Dict[str, Any]:
         """Serialize accumulator state for cross-rank reduction."""
@@ -530,6 +541,7 @@ def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> list[dict]:

     def reset(self) -> None:
         """Clear local samples and reset filter state."""
+        self.is_reset = True
         self.samples.clear()
         self.filter.reset()

@@ -701,12 +713,12 @@ def push(self, metric: Metric) -> None:

         # For PER_RANK_NO_REDUCE backends: stream without reduce
         for backend in self.per_rank_no_reduce_backends:
-            if metric.reduction == Reduce.SAMPLE:
-                # Wrap singleton Metric into expected {key: [list_of_dicts]} format
-                sample = {metric.key: [metric.value]}
-                asyncio.create_task(backend.log_samples(sample, self.global_step))
-            else:
-                backend.log_stream(metric=metric, global_step=self.global_step)
+            # if metric.reduction == Reduce.SAMPLE:
+            #     # Wrap singleton Metric into expected {key: [list_of_dicts]} format
+            #     sample = {metric.key: [metric.value]}
+            #     asyncio.create_task(backend.log_samples(sample, self.global_step))
+            # else:
+            backend.log_stream(metric=metric, global_step=self.global_step)

         # Always accumulate for reduction and state return
         key = metric.key
@@ -773,6 +785,7 @@ async def flush(

         for backend in self.per_rank_reduce_backends:
             if scalar_metrics:
+                print(f"[DEBUG] calling log_batch from MetricCollector")
                 await backend.log_batch(scalar_metrics, global_step)
             if sample_metrics:
                 await backend.log_samples(sample_metrics, global_step)
@@ -895,6 +908,7 @@ async def init(
     async def log_batch(
         self, metrics: list[Metric], global_step: int, *args, **kwargs
     ) -> None:
+        print(f"[DEBUG] calling log_batch with {len(metrics)} metrics")
         metrics_str = "\n".join(
             f"  {metric.key}: {metric.value}"
             for metric in sorted(metrics, key=lambda m: m.key)
@@ -913,6 +927,8 @@ async def log_samples(self, samples: Dict[str, List[dict]], step: int) -> None:
         """Pretty-print sample-level logs to console."""
         import json

+        print(f"[DEBUG] calling log_samples with {len(samples)} samples")
+
         logger.info(f"========== SAMPLE LOGS STEP {step} ==========")
         for table_name, table_rows in samples.items():
             logger.info(f"[{table_name}] ({len(table_rows)} samples)")
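
Note for reviewers: the middle hunks instrument a sample accumulator whose contract is only partially visible in this diff. Below is a minimal, self-contained Python sketch of that pattern, for orientation only. The method names (`append`, `get_value`, `reset`) and the new `is_reset` flag follow the diff; `PassthroughFilter` is a hypothetical stand-in for `TopBottomKFilter`, whose selection logic is not shown here, and everything else is an assumption rather than the repository's actual implementation.

from typing import Any, Dict, List


class PassthroughFilter:
    """Hypothetical stand-in for TopBottomKFilter: admit every sample on
    append and return the collected list unchanged at flush time."""

    def filter_append(self, value: Dict[str, Any]) -> bool:
        return True  # a real filter could reject samples here

    def filter_flush(self, samples: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        return samples  # a real filter could do heap selection or trimming here

    def reset(self) -> None:
        pass


class SampleAccumulator:
    """Collect dict-valued samples locally, filtering at append and flush."""

    def __init__(self) -> None:
        self.samples: List[Dict[str, Any]] = []
        self.filter = PassthroughFilter()
        self.is_reset = True  # the flag this commit adds

    def append(self, value: Dict[str, Any]) -> None:
        if not isinstance(value, dict):
            raise ValueError(f"Expected dict, got {type(value)}")
        self.is_reset = False
        if self.filter.filter_append(value):
            self.samples.append(value)

    def get_value(self) -> List[Dict[str, Any]]:
        # Apply the flush-time filter before handing samples to a backend.
        return self.filter.filter_flush(self.samples)

    def reset(self) -> None:
        self.is_reset = True
        self.samples.clear()
        self.filter.reset()


acc = SampleAccumulator()
acc.append({"episode_id": 1, "reward": 0.5})
print(acc.get_value())  # -> [{'episode_id': 1, 'reward': 0.5}]
acc.reset()
print(acc.is_reset)     # -> True

The `is_reset` flag gives callers a cheap way to tell whether the accumulator has seen samples since the last flush; together with the debug prints in the surrounding hunks, it appears aimed at tracing whether samples recorded via record_metric actually survive to the flush path.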