
Commit 7124e44

feat: enhance advantages tracking and normalization stability in GRPO (#1423)
Signed-off-by: Felipe Vieira Frujeri <ffrujeri@nvidia.com>
1 parent 779f775 commit 7124e44


4 files changed: +347 −9 lines changed


nemo_rl/algorithms/grpo.py

Lines changed: 64 additions & 7 deletions
@@ -538,6 +538,25 @@ def setup(
 # ===============================================================================


+def normalize_advantages_with_epsilon(
+    advantages: torch.Tensor,
+    std: torch.Tensor,
+    epsilon: float = 1e-6,
+) -> torch.Tensor:
+    """Normalize advantages by standard deviation with epsilon to avoid division by zero.
+
+    Args:
+        advantages: Tensor of shape (batch_size, 1) containing advantage values
+        std: Tensor of shape (batch_size,) containing standard deviation values
+        epsilon: Small value to avoid division by zero, defaults to 1e-6
+
+    Returns:
+        Normalized advantages tensor of same shape as input advantages
+    """
+    # Use epsilon to avoid division by zero instead of masking
+    return advantages / (std.unsqueeze(-1) + epsilon)
+
+
 def dynamic_sampling(
     repeated_batch: BatchedDataDict[DatumSpec],
     std: torch.Tensor,
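
For context on what the helper changes, here is a minimal standalone sketch (not part of the commit; the toy rewards, baseline, and std values are assumptions for illustration). It contrasts the removed mask-based normalization with the new epsilon-based division: for a prompt group with zero reward variance the advantages are already zero, so dividing by epsilon leaves them at zero and the two approaches agree, without the in-place masked assignment.

import torch

def normalize_advantages_with_epsilon(advantages, std, epsilon=1e-6):
    # Same formula as the new helper above.
    return advantages / (std.unsqueeze(-1) + epsilon)

# Toy batch: two samples each for two prompts; the second prompt's rewards tie.
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5])
baseline = torch.tensor([0.5, 0.5, 0.5, 0.5])    # per-prompt mean reward
std = torch.tensor([0.7071, 0.7071, 0.0, 0.0])   # per-prompt sample std

advantages = (rewards - baseline).unsqueeze(-1)

# Old approach: normalize only where std > 0 ("don't sharpen the ones with no variation").
masked = advantages.clone()
zero_std_mask = std > 0
masked[zero_std_mask] = masked[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask]

# New approach: divide everywhere; epsilon keeps the zero-std rows finite, and since
# their advantages are already zero the result is unchanged for them.
eps_normalized = normalize_advantages_with_epsilon(advantages, std)

print(torch.allclose(masked, eps_normalized, atol=1e-4))  # True for this toy batch
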
@@ -1056,10 +1075,9 @@ def grpo_train(
         advantages = (rewards - baseline).unsqueeze(-1)

         if master_config["grpo"]["normalize_rewards"]:
-            # don't sharpen the ones with no variation
-            zero_std_mask = std > 0
-            advantages[zero_std_mask] = (
-                advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask]
+            advantages = normalize_advantages_with_epsilon(
+                advantages=advantages,
+                std=std,
             )

         with timer.time("data_processing"):
@@ -1172,12 +1190,31 @@ def grpo_train(
                 val_metrics, total_steps + 1, prefix="validation"
             )

+        # Get flat advantages and token mask for masked metrics computation
+        flat_advantages = flat_messages["advantages"]
+        flat_token_mask = flat_messages["token_loss_mask"]
+
+        # Filter advantages using token mask (only valid response tokens)
+        response_advantages = torch.masked_select(
+            flat_advantages, flat_token_mask.bool()
+        )
+
         metrics = {
             "loss": train_results["loss"].numpy(),
             "grad_norm": train_results["grad_norm"].numpy(),
             "reward": rewards.numpy(),
             "mean_prompt_length": repeated_batch["length"].numpy(),
             "total_num_tokens": input_lengths.numpy(),
+            # Add masked advantages tracking metrics (only for valid response tokens)
+            "advantages/mean": torch.mean(response_advantages).detach().item()
+            if response_advantages.numel() > 0
+            else 0.0,
+            "advantages/max": torch.max(response_advantages).detach().item()
+            if response_advantages.numel() > 0
+            else 0.0,
+            "advantages/min": torch.min(response_advantages).detach().item()
+            if response_advantages.numel() > 0
+            else 0.0,
             **ds_metrics,
         }
         if master_config["grpo"]["use_dynamic_sampling"]:
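
To make the new advantages/* metrics concrete, below is an illustrative sketch with assumed toy tensors (the values of flat_advantages and flat_token_mask are stand-ins, not taken from the repo's data structures). Prompt and padding positions carry a zero token_loss_mask, so torch.masked_select keeps only valid response tokens before the mean/max/min are taken; the numel() guard covers the edge case where nothing survives the mask.

import torch

# Assumed toy tensors: six token positions, two of them masked out.
flat_advantages = torch.tensor([0.8, 0.8, 0.0, -1.2, -1.2, 0.0])
flat_token_mask = torch.tensor([1, 1, 0, 1, 1, 0])

# Keep only valid response tokens, mirroring the logged metrics above.
response_advantages = torch.masked_select(flat_advantages, flat_token_mask.bool())

metrics = {
    "advantages/mean": response_advantages.mean().item() if response_advantages.numel() > 0 else 0.0,
    "advantages/max": response_advantages.max().item() if response_advantages.numel() > 0 else 0.0,
    "advantages/min": response_advantages.min().item() if response_advantages.numel() > 0 else 0.0,
}
print(metrics)  # mean ≈ -0.2, max = 0.8, min = -1.2
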
@@ -1929,10 +1966,11 @@ def async_grpo_train(
         )

         if master_config["grpo"]["normalize_rewards"]:
-            zero_std_mask = std > 0
-            advantages[zero_std_mask] = (
-                advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask]
+            advantages = normalize_advantages_with_epsilon(
+                advantages=advantages,
+                std=std,
             )
+
         print(
             f" 📊 Normalized advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}"
         )
@@ -2060,12 +2098,31 @@ def async_grpo_train(

             # Resume trajectory collection after validation
             trajectory_collector.resume.remote()
+        # Get flat advantages and token mask for masked metrics computation
+        flat_advantages = flat_messages["advantages"]
+        flat_token_mask = flat_messages["token_loss_mask"]
+
+        # Filter advantages using token mask (only valid response tokens)
+        response_advantages = torch.masked_select(
+            flat_advantages, flat_token_mask.bool()
+        )
+
         metrics = {
             "loss": train_results["loss"].numpy(),
             "reward": rewards.numpy(),
             "grad_norm": train_results["grad_norm"].numpy(),
             "mean_prompt_length": repeated_batch["length"].numpy(),
             "total_num_tokens": input_lengths.numpy(),
+            # Add masked advantages tracking metrics (only for valid response tokens)
+            "advantages/mean": torch.mean(response_advantages).detach().item()
+            if response_advantages.numel() > 0
+            else 0.0,
+            "advantages/max": torch.max(response_advantages).detach().item()
+            if response_advantages.numel() > 0
+            else 0.0,
+            "advantages/min": torch.min(response_advantages).detach().item()
+            if response_advantages.numel() > 0
+            else 0.0,
         }
         metrics.update(train_results["all_mb_metrics"])
         for k, v in metrics.items():

nemo_rl/algorithms/utils.py

Lines changed: 10 additions & 2 deletions
@@ -99,11 +99,12 @@ def calculate_baseline_and_std_per_prompt(

     baseline = torch.zeros_like(rewards)
     sq_baseline = torch.zeros_like(rewards)
+    std = torch.zeros_like(rewards)
     device_ordinal = rewards.get_device()
     if device_ordinal == -1:
         reward_device = torch.device("cpu")
     else:
-        reward_device = torch.device(reward_device)
+        reward_device = torch.device(f"cuda:{device_ordinal}")

     for i in range(len(unique_prompts)):
         is_matching_prompt = (prompts == unique_prompts[i]).all(1)
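
A small aside on the device fix above (a sketch under the assumption that Tensor.get_device() is the only device signal available here, as in the hunk; resolve_reward_device is a hypothetical wrapper for illustration): get_device() returns -1 for CPU tensors and the CUDA ordinal otherwise, so the device has to be rebuilt from that ordinal, whereas the replaced line read reward_device before it was ever assigned on the GPU path.

import torch

def resolve_reward_device(rewards: torch.Tensor) -> torch.device:
    # -1 means the tensor lives on the CPU; otherwise it is the CUDA ordinal.
    device_ordinal = rewards.get_device()
    if device_ordinal == -1:
        return torch.device("cpu")
    return torch.device(f"cuda:{device_ordinal}")

print(resolve_reward_device(torch.zeros(3)))  # cpu
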
@@ -142,8 +143,15 @@ def calculate_baseline_and_std_per_prompt(

             baseline[prompt_idx] = prompt_baseline
             sq_baseline[prompt_idx] = prompt_baseline_square
+            std[prompt_idx] = (
+                (
+                    (prompt_baseline_square - prompt_baseline.square())
+                    * (num_valid / (num_valid - 1))
+                )
+                .sqrt()
+                .nan_to_num(0)
+            )

-    std = (sq_baseline - baseline.square()).sqrt().nan_to_num(0)
     return baseline, std


tests/unit/algorithms/test_grpo.py

Lines changed: 73 additions & 0 deletions
@@ -24,6 +24,7 @@
     async_grpo_train,
     dynamic_sampling,
     grpo_train,
+    normalize_advantages_with_epsilon,
 )
 from nemo_rl.algorithms.loss_functions import ClippedPGLossFn
 from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
@@ -1208,3 +1209,75 @@ def test_grpo_exit_on_timeout(mock_grpo_components, train_func, capsys):
         assert not (line.startswith("Step ") and "Step 9" in line), (
             f"Training continued to next step after timeout: {line}"
         )
+
+
+# ============================================================================
+# Tests for normalize_advantages_with_epsilon function
+# ============================================================================
+
+
+def test_normalize_advantages_with_epsilon_basic():
+    """Test basic functionality of normalize_advantages_with_epsilon."""
+    # Test case with normal values
+    advantages = torch.tensor([[2.0], [4.0], [6.0]])
+    std = torch.tensor([1.0, 2.0, 3.0])
+    epsilon = 1e-6
+
+    result = normalize_advantages_with_epsilon(advantages, std, epsilon)
+
+    expected = torch.tensor([[2.0], [2.0], [2.0]])
+    assert torch.allclose(result, expected, rtol=1e-5)
+
+
+def test_normalize_advantages_with_epsilon_zero_std():
+    """Test normalize_advantages_with_epsilon when std contains zeros."""
+    advantages = torch.tensor([[1.0], [2.0], [3.0]])
+    std = torch.tensor([0.0, 1.0, 0.0])  # Zero std for indices 0 and 2
+    epsilon = 1e-6
+
+    result = normalize_advantages_with_epsilon(advantages, std, epsilon)
+
+    # When std=0, result should be advantages / epsilon
+    expected = torch.tensor([[1.0 / epsilon], [2.0], [3.0 / epsilon]])
+    assert torch.allclose(result, expected, rtol=1e-5)
+
+
+def test_normalize_advantages_with_epsilon_all_zero_std():
+    """Test normalize_advantages_with_epsilon when all std values are zero."""
+    advantages = torch.tensor([[1.5], [2.5], [3.5]])
+    std = torch.tensor([0.0, 0.0, 0.0])
+    epsilon = 1e-8
+
+    result = normalize_advantages_with_epsilon(advantages, std, epsilon)
+
+    expected = advantages / epsilon
+    assert torch.allclose(result, expected, rtol=1e-5)
+
+
+def test_normalize_advantages_with_epsilon_tensor_shapes():
+    """Test normalize_advantages_with_epsilon with different tensor shapes."""
+    # Test with batch size 1
+    advantages = torch.tensor([[5.0]])
+    std = torch.tensor([2.0])
+    result = normalize_advantages_with_epsilon(advantages, std)
+    expected = torch.tensor([[2.5]])
+    assert torch.allclose(result, expected, rtol=1e-5)
+
+    # Test with larger batch
+    batch_size = 10
+    advantages = torch.ones(batch_size, 1) * 3.0
+    std = torch.ones(batch_size) * 1.5
+    result = normalize_advantages_with_epsilon(advantages, std)
+    expected = torch.ones(batch_size, 1) * 2.0
+    assert torch.allclose(result, expected, rtol=1e-5)
+
+
+def test_normalize_advantages_with_epsilon_negative_advantages():
+    """Test normalize_advantages_with_epsilon with negative advantages."""
+    advantages = torch.tensor([[-2.0], [3.0], [-1.5]])
+    std = torch.tensor([1.0, 1.5, 0.5])
+
+    result = normalize_advantages_with_epsilon(advantages, std)
+
+    expected = torch.tensor([[-2.0], [2.0], [-3.0]])
+    assert torch.allclose(result, expected, rtol=1e-5)
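
Assuming the repository's standard pytest layout, the new cases can be exercised in isolation with pytest tests/unit/algorithms/test_grpo.py -k normalize_advantages_with_epsilon; the -k filter selects the five tests above by name.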
