16 changes: 12 additions & 4 deletions nemo_rl/algorithms/grpo.py
@@ -638,18 +638,26 @@ def normalize_advantages_with_epsilon(
     std: torch.Tensor,
     epsilon: float = 1e-6,
 ) -> torch.Tensor:
-    """Normalize advantages by standard deviation with epsilon to avoid division by zero.
+    """Normalize advantages by standard deviation, skipping samples with zero std.
+
+    When std is exactly zero (from the leave-one-out baseline with identical rewards),
+    normalization is skipped for those samples to prevent numerical instability.
+    This makes normalize_rewards compatible with use_leave_one_out_baseline.
 
     Args:
         advantages: Tensor of shape (batch_size, 1) containing advantage values
         std: Tensor of shape (batch_size,) containing standard deviation values
-        epsilon: Small value to avoid division by zero, defaults to 1e-6
+        epsilon: Small value added to std to avoid division by very small values, defaults to 1e-6
 
     Returns:
         Normalized advantages tensor of same shape as input advantages
     """
-    # Use epsilon to avoid division by zero instead of masking
-    return advantages / (std.unsqueeze(-1) + epsilon)
+    # Only normalize where std > 0: dividing a nonzero advantage by epsilon alone would inflate it
+    non_zero_std_mask = std > 0
+    advantages[non_zero_std_mask] = advantages[non_zero_std_mask] / (
+        std.unsqueeze(-1)[non_zero_std_mask] + epsilon
+    )
+    return advantages
 
 
 def dynamic_sampling(
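Note: the docstring's leave-one-out motivation is easy to verify numerically. Below is a minimal standalone sketch (hypothetical names, not the repo's actual helpers) that reproduces the zero-std case from rewards [1.0, 0.0, 0.0, 0.0] and applies the same masked normalization as the patch.

import torch

# Prompt group rewards where one rollout succeeds and the rest are identical.
rewards = torch.tensor([1.0, 0.0, 0.0, 0.0])
n = rewards.numel()

# Leave-one-out baseline: each sample is compared against the mean of the others.
baseline = (rewards.sum() - rewards) / (n - 1)
advantages = (rewards - baseline).unsqueeze(-1)  # shape (batch_size, 1)

# Leave-one-out std: for sample 0 the "others" are [0, 0, 0], so std is exactly 0.
loo_std = torch.stack(
    [torch.cat([rewards[:i], rewards[i + 1 :]]).std() for i in range(n)]
)
# loo_std -> tensor([0.0000, 0.5774, 0.5774, 0.5774])

# Masked normalization as in the patch: sample 0 keeps its advantage of 1.0;
# dividing it by epsilon alone would blow it up to ~1e6.
epsilon = 1e-6
mask = loo_std > 0
advantages[mask] = advantages[mask] / (loo_std.unsqueeze(-1)[mask] + epsilon)
# advantages -> approximately tensor([[1.0000], [-0.5773], [-0.5773], [-0.5773]])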
71 changes: 68 additions & 3 deletions tests/unit/algorithms/test_grpo.py
@@ -1237,8 +1237,11 @@ def test_normalize_advantages_with_epsilon_zero_std():

     result = normalize_advantages_with_epsilon(advantages, std, epsilon)
 
-    # When std=0, result should be advantages / epsilon
-    expected = torch.tensor([[1.0 / epsilon], [2.0], [3.0 / epsilon]])
+    # When std=0, normalization is skipped (advantages unchanged)
+    # When std>0, normal normalization occurs
+    expected = torch.tensor(
+        [[1.0], [2.0], [3.0]]
+    )  # Samples 0 and 2 unchanged; sample 1 normalized
     assert torch.allclose(result, expected, rtol=1e-5)


@@ -1248,9 +1251,12 @@ def test_normalize_advantages_with_epsilon_all_zero_std():
     std = torch.tensor([0.0, 0.0, 0.0])
     epsilon = 1e-8
 
+    # Save expected values BEFORE calling the function (it modifies advantages in place)
+    expected = advantages.clone()
+
     result = normalize_advantages_with_epsilon(advantages, std, epsilon)
 
-    expected = advantages / epsilon
+    # When std=0, normalization is skipped (all samples unchanged)
     assert torch.allclose(result, expected, rtol=1e-5)


@@ -1281,3 +1287,62 @@ def test_normalize_advantages_with_epsilon_negative_advantages():

     expected = torch.tensor([[-2.0], [2.0], [-3.0]])
     assert torch.allclose(result, expected, rtol=1e-5)
+
+
+def test_normalize_advantages_with_zero_std_from_leave_one_out():
+    """Test that zero std (from the leave-one-out baseline) is handled gracefully by skipping normalization."""
+    # Simulate the leave-one-out case: rewards [1.0, 0.0, 0.0, 0.0]
+    # Sample 0 has baseline from [0, 0, 0] -> std=0, advantage=1.0
+    # Samples 1-3 have baseline from [1, 0, 0] -> std≈0.577, advantage≈-0.333
+    advantages = torch.tensor([[1.0], [-0.333], [-0.333], [-0.333]])
+    std = torch.tensor([0.0, 0.577, 0.577, 0.577])
+    epsilon = 1e-6
+
+    # Compute expected values BEFORE calling the function (it modifies advantages in place)
+    expected_sample_0 = advantages[0].clone()
+    expected_normalized = advantages[1:].clone() / (std[1:].unsqueeze(-1) + epsilon)
+
+    result = normalize_advantages_with_epsilon(advantages, std, epsilon)
+
+    # Sample 0: std=0 -> advantage unchanged (normalization skipped)
+    assert torch.allclose(result[0], expected_sample_0, rtol=1e-5)
+
+    # Samples 1-3: std>0 -> normalized with epsilon
+    assert torch.allclose(result[1:], expected_normalized, rtol=1e-5)
+
+
+def test_normalize_advantages_with_zero_std_and_zero_advantage():
+    """Test that zero std with zero advantage is left unchanged."""
+    advantages = torch.tensor([[0.0], [1.0], [0.0]])
+    std = torch.tensor([0.0, 0.0, 1.0])
+    epsilon = 1e-6
+
+    # Compute expected values BEFORE calling the function (it modifies advantages in place)
+    expected_sample_0 = advantages[0].clone()
+    expected_sample_1 = advantages[1].clone()
+    expected_sample_2 = advantages[2].clone() / (std[2] + epsilon)
+
+    result = normalize_advantages_with_epsilon(advantages, std, epsilon)
+
+    # Sample 0: std=0, advantage=0 -> unchanged (normalization skipped)
+    assert torch.allclose(result[0], expected_sample_0, rtol=1e-5)
+
+    # Sample 1: std=0, advantage!=0 -> unchanged (normalization skipped)
+    assert torch.allclose(result[1], expected_sample_1, rtol=1e-5)
+
+    # Sample 2: std>0 -> normalized with epsilon
+    assert torch.allclose(result[2], expected_sample_2, rtol=1e-5)
+
+
+def test_normalize_advantages_with_small_nonzero_std():
+    """Test that small but non-zero std values still get normalized (no threshold)."""
+    advantages = torch.tensor([[2.0], [3.0], [-1.0]])
+    std = torch.tensor([0.001, 0.01, 0.0001])  # All small but non-zero
+
+    # Compute expected values BEFORE calling the function (it modifies advantages in place)
+    expected = advantages.clone() / (std.unsqueeze(-1) + 1e-6)
+
+    result = normalize_advantages_with_epsilon(advantages, std)
+
+    # All samples should be normalized since std > 0
+    assert torch.allclose(result, expected, rtol=1e-5)
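Note: one behavioral detail worth calling out, since every new test guards against it: the masked assignment mutates its input, so the returned tensor aliases the argument. A small sketch (assuming the same masked-division body as the patch) illustrating why the tests clone expected values before the call:

import torch

def normalize_advantages_with_epsilon(advantages, std, epsilon=1e-6):
    # Same body as the patched function: in-place masked division.
    mask = std > 0
    advantages[mask] = advantages[mask] / (std.unsqueeze(-1)[mask] + epsilon)
    return advantages

advantages = torch.tensor([[2.0], [3.0]])
std = torch.tensor([0.5, 0.0])
result = normalize_advantages_with_epsilon(advantages, std)

print(result is advantages)  # True: no copy is made
print(advantages)            # tensor([[4.0000], [3.0000]]): row 0 was rewritten
# Callers that still need the raw advantages must clone() before calling,
# exactly as the tests above do when computing expected values.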