@@ -467,7 +467,7 @@ def compute_mismatch_metrics(
467467 - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
468468 """
469469 metrics = {}
470-
470+ metrics [ "valid" ] = response_mask . any (). item ()
471471 # 1. Training policy perplexity (always available)
472472 # Formula: exp(-1/|T| * Σ log π_training(y_t|y_<t))
473473 # where |T| is the number of tokens generated by the model
@@ -529,18 +529,32 @@ def compute_mismatch_metrics(
529529
530530def merge_rollout_is_metrics (rollout_is_metrics : list [dict [str , float ]], device = "cuda" ) -> dict [str , float ]:
531531 metrics = {}
532- for key in rollout_is_metrics [0 ].keys ():
533- all_values = [m [key ] for m in rollout_is_metrics ]
532+ keys = [k for k in rollout_is_metrics [0 ].keys () if k != "mismatch/valid" ]
533+
534+ for key in keys :
535+ values = []
536+ valids = []
537+ for m in rollout_is_metrics :
538+ is_valid = m .get ("mismatch/valid" , True )
539+ valids .append (float (is_valid ))
540+ values .append (m [key ] if is_valid else 0.0 ) # set to 0.0 if invalid
541+ value_tensor = torch .tensor (values , dtype = torch .float32 , device = device )
542+ valid_tensor = torch .tensor (valids , dtype = torch .float32 , device = device )
543+
544+ # Aggregate across all processes
534545 if "max" in key :
535- max_value = torch . tensor ( all_values ). max (). to ( torch . float32 ). to ( device )
546+ max_value = value_tensor . max ()
536547 dist .all_reduce (max_value , op = dist .ReduceOp .MAX )
537548 metrics [key ] = max_value .item ()
538549 elif "min" in key :
539- min_value = torch . tensor ( all_values ). min (). to ( torch . float32 ). to ( device )
550+ min_value = value_tensor . min ()
540551 dist .all_reduce (min_value , op = dist .ReduceOp .MIN )
541552 metrics [key ] = min_value .item ()
542553 else :
543- mean_value = torch .tensor (all_values ).mean ().to (torch .float32 ).to (device )
544- dist .all_reduce (mean_value , op = dist .ReduceOp .AVG )
545- metrics [key ] = mean_value .item ()
554+ sum_value = value_tensor .sum ()
555+ count_value = valid_tensor .sum ()
556+ dist .all_reduce (sum_value , op = dist .ReduceOp .SUM )
557+ dist .all_reduce (count_value , op = dist .ReduceOp .SUM )
558+ metrics [key ] = sum_value .item () / count_value .item () if count_value .item () > 0 else 0.0
559+
546560 return metrics
0 commit comments