@@ -538,6 +538,25 @@ def setup(
 # ===============================================================================
 
 
+def normalize_advantages_with_epsilon(
+    advantages: torch.Tensor,
+    std: torch.Tensor,
+    epsilon: float = 1e-6,
+) -> torch.Tensor:
+    """Normalize advantages by standard deviation with epsilon to avoid division by zero.
+
+    Args:
+        advantages: Tensor of shape (batch_size, 1) containing advantage values
+        std: Tensor of shape (batch_size,) containing standard deviation values
+        epsilon: Small value to avoid division by zero, defaults to 1e-6
+
+    Returns:
+        Normalized advantages tensor of same shape as input advantages
+    """
+    # Use epsilon to avoid division by zero instead of masking
+    return advantages / (std.unsqueeze(-1) + epsilon)
+
+
 def dynamic_sampling(
     repeated_batch: BatchedDataDict[DatumSpec],
     std: torch.Tensor,
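
A minimal sketch (not part of the diff) of how the new helper behaves on toy tensors, assuming the function above is in scope; the reward, baseline, and std values are made up for illustration:

import torch

# Toy example: the first two samples belong to a group with no reward variation
# (std = 0, so their raw advantage is already 0); the last two have std = 1.
rewards = torch.tensor([1.0, 1.0, 0.0, 2.0])
baseline = torch.tensor([1.0, 1.0, 1.0, 1.0])
std = torch.tensor([0.0, 0.0, 1.0, 1.0])

advantages = (rewards - baseline).unsqueeze(-1)  # shape (4, 1)
normalized = normalize_advantages_with_epsilon(advantages, std)
print(normalized.squeeze(-1))  # ~tensor([ 0.0000,  0.0000, -1.0000,  1.0000])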
@@ -1056,10 +1075,9 @@ def grpo_train(
         advantages = (rewards - baseline).unsqueeze(-1)
 
         if master_config["grpo"]["normalize_rewards"]:
-            # don't sharpen the ones with no variation
-            zero_std_mask = std > 0
-            advantages[zero_std_mask] = (
-                advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask]
+            advantages = normalize_advantages_with_epsilon(
+                advantages=advantages,
+                std=std,
             )
 
         with timer.time("data_processing"):
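
For context, a comparison sketch (not from the PR) of the removed masking branch against the new epsilon-based call. `old_normalize` is a hypothetical re-implementation of the deleted lines, and `normalize_advantages_with_epsilon` is assumed to be in scope; zero-std rows already carry a zero advantage, so the two paths differ only by a factor of std / (std + 1e-6):

import torch

def old_normalize(advantages: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # Hypothetical re-implementation of the removed branch, for comparison only:
    # rows with std == 0 are left untouched instead of being divided.
    out = advantages.clone()
    mask = std > 0
    out[mask] = advantages[mask] / std.unsqueeze(-1)[mask]
    return out

advantages = torch.tensor([[0.0], [0.5], [-0.5]])
std = torch.tensor([0.0, 0.25, 0.25])

print(old_normalize(advantages, std).squeeze(-1))
# tensor([ 0.,  2., -2.])
print(normalize_advantages_with_epsilon(advantages, std).squeeze(-1))
# ~tensor([ 0.0000,  2.0000, -2.0000])  (off from the old path only by the epsilon)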
@@ -1172,12 +1190,31 @@ def grpo_train(
                     val_metrics, total_steps + 1, prefix="validation"
                 )
 
+            # Get flat advantages and token mask for masked metrics computation
+            flat_advantages = flat_messages["advantages"]
+            flat_token_mask = flat_messages["token_loss_mask"]
+
+            # Filter advantages using token mask (only valid response tokens)
+            response_advantages = torch.masked_select(
+                flat_advantages, flat_token_mask.bool()
+            )
+
             metrics = {
                 "loss": train_results["loss"].numpy(),
                 "grad_norm": train_results["grad_norm"].numpy(),
                 "reward": rewards.numpy(),
                 "mean_prompt_length": repeated_batch["length"].numpy(),
                 "total_num_tokens": input_lengths.numpy(),
+                # Add masked advantages tracking metrics (only for valid response tokens)
+                "advantages/mean": torch.mean(response_advantages).detach().item()
+                if response_advantages.numel() > 0
+                else 0.0,
+                "advantages/max": torch.max(response_advantages).detach().item()
+                if response_advantages.numel() > 0
+                else 0.0,
+                "advantages/min": torch.min(response_advantages).detach().item()
+                if response_advantages.numel() > 0
+                else 0.0,
                 **ds_metrics,
             }
             if master_config["grpo"]["use_dynamic_sampling"]:
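
A small sketch (illustrative values only, assuming the same tensor layout as flat_messages) of how torch.masked_select yields the response-only advantages and how the numel() guard keeps the logged statistics defined when the mask is empty:

import torch

# Stand-ins for flat_messages["advantages"] and flat_messages["token_loss_mask"]:
# prompt tokens carry mask 0, response tokens mask 1.
flat_advantages = torch.tensor([0.0, 0.0, 1.2, 1.2, -0.7])
flat_token_mask = torch.tensor([0, 0, 1, 1, 1])

# masked_select returns a flat 1-D tensor containing only the masked-in elements.
response_advantages = torch.masked_select(flat_advantages, flat_token_mask.bool())
print(response_advantages)  # tensor([ 1.2000,  1.2000, -0.7000])

# Same guard as in the metrics dict above: fall back to 0.0 when nothing survives.
mean_adv = response_advantages.mean().item() if response_advantages.numel() > 0 else 0.0
print(round(mean_adv, 4))  # 0.5667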
@@ -1929,10 +1966,11 @@ def async_grpo_train(
             )
 
             if master_config["grpo"]["normalize_rewards"]:
-                zero_std_mask = std > 0
-                advantages[zero_std_mask] = (
-                    advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask]
+                advantages = normalize_advantages_with_epsilon(
+                    advantages=advantages,
+                    std=std,
                 )
+
             print(
                 f"  📊 Normalized advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}"
             )
@@ -2060,12 +2098,31 @@ def async_grpo_train(
 
                 # Resume trajectory collection after validation
                 trajectory_collector.resume.remote()
+            # Get flat advantages and token mask for masked metrics computation
+            flat_advantages = flat_messages["advantages"]
+            flat_token_mask = flat_messages["token_loss_mask"]
+
+            # Filter advantages using token mask (only valid response tokens)
+            response_advantages = torch.masked_select(
+                flat_advantages, flat_token_mask.bool()
+            )
+
             metrics = {
                 "loss": train_results["loss"].numpy(),
                 "reward": rewards.numpy(),
                 "grad_norm": train_results["grad_norm"].numpy(),
                 "mean_prompt_length": repeated_batch["length"].numpy(),
                 "total_num_tokens": input_lengths.numpy(),
+                # Add masked advantages tracking metrics (only for valid response tokens)
+                "advantages/mean": torch.mean(response_advantages).detach().item()
+                if response_advantages.numel() > 0
+                else 0.0,
+                "advantages/max": torch.max(response_advantages).detach().item()
+                if response_advantages.numel() > 0
+                else 0.0,
+                "advantages/min": torch.min(response_advantages).detach().item()
+                if response_advantages.numel() > 0
+                else 0.0,
             }
             metrics.update(train_results["all_mb_metrics"])
             for k, v in metrics.items():