Skip to content

Commit 198c991

Browse files
committed
add unit test for gdpo estimator and multi-reward env
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent e23fa7e commit 198c991

File tree

2 files changed

+160
-22
lines changed

2 files changed

+160
-22
lines changed

tests/unit/algorithms/test_grpo.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from torchdata.stateful_dataloader import StatefulDataLoader
2121

2222
from nemo_rl.algorithms.advantage_estimator import (
23+
GDPOAdvantageEstimator,
2324
GRPOAdvantageEstimator,
2425
ReinforcePlusPlusAdvantageEstimator,
2526
)
@@ -1810,6 +1811,51 @@ def test_grpo_advantage_estimator_small_nonzero_std():
18101811
# ============================================================================
18111812

18121813

1814+
def test_gdpo_advantage_estimator_multiple_rewards():
1815+
"""Test GDPOAdvantageEstimator with multiple rewards."""
1816+
estimator_config = {
1817+
"use_leave_one_out_baseline": False,
1818+
"normalize_rewards": True,
1819+
}
1820+
loss_config = {}
1821+
estimator = GDPOAdvantageEstimator(estimator_config, loss_config)
1822+
1823+
prompt_ids = torch.tensor([[0], [0]])
1824+
mask = torch.ones(2, 3)
1825+
repeated_batch = {
1826+
"reward1": torch.tensor([1.0, 1.0]),
1827+
"reward2": torch.tensor([1.0, -1.0]),
1828+
"reward3": torch.tensor([1.0, 0.0]),
1829+
}
1830+
1831+
result = estimator.compute_advantage(prompt_ids, None, mask, repeated_batch)
1832+
assert result.shape == (2, 3)
1833+
assert torch.allclose(result[0, 0], torch.tensor(0.7071))
1834+
assert torch.allclose(result[1, 0], torch.tensor(-0.7071))
1835+
1836+
1837+
def test_gdpo_advantage_estimator_single_reward():
1838+
"""Test GDPOAdvantageEstimator raises ValueError with a single reward."""
1839+
estimator_config = {
1840+
"use_leave_one_out_baseline": False,
1841+
"normalize_rewards": True,
1842+
}
1843+
loss_config = {}
1844+
estimator = GDPOAdvantageEstimator(estimator_config, loss_config)
1845+
1846+
prompt_ids = torch.tensor([[0], [0]])
1847+
mask = torch.ones(2, 3)
1848+
repeated_batch = {"reward1": torch.tensor([1.0, 3.0])}
1849+
1850+
with pytest.raises(ValueError):
1851+
estimator.compute_advantage(prompt_ids, None, mask, repeated_batch)
1852+
1853+
1854+
# ============================================================================
1855+
# Tests for ReinforcePlusPlusAdvantageEstimator class
1856+
# ============================================================================
1857+
1858+
18131859
def test_reinforce_plus_plus_global_normalization():
18141860
"""Test that ReinforcePlusPlusAdvantageEstimator applies global normalization.
18151861

tests/unit/environments/test_math_environment.py

Lines changed: 114 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,34 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import os
1514
import time
1615

1716
import pytest
1817
import ray
1918

20-
from nemo_rl.distributed.ray_actor_environment_registry import (
21-
get_actor_python_env,
22-
)
23-
from nemo_rl.environments.math_environment import MathEnvironment
19+
from nemo_rl.environments.utils import create_env
20+
21+
# ============================================================================
22+
# Environment fixtures
23+
# ============================================================================
2424

2525

2626
@pytest.fixture(scope="module")
2727
def math_env():
2828
"""Create a MathEnvironment actor for testing."""
29-
env = MathEnvironment.options(
30-
runtime_env={
31-
"py_executable": get_actor_python_env(
32-
"nemo_rl.environments.math_environment.MathEnvironment"
33-
),
34-
"env_vars": dict(os.environ),
35-
}
36-
).remote({"num_workers": 2})
29+
env = create_env("math", {"num_workers": 2})
30+
yield env
31+
# Clean up the actor and wait for it to be killed
32+
env.shutdown.remote()
33+
ray.kill(env)
34+
# Give some time for cleanup
35+
time.sleep(0.1)
36+
37+
38+
@pytest.fixture(scope="module")
39+
def math_multi_reward_env():
40+
"""Create a MathMultiRewardEnvironment actor for testing."""
41+
env = create_env("math_multi_reward", {"num_workers": 2})
3742
yield env
3843
# Clean up the actor and wait for it to be killed
3944
env.shutdown.remote()
@@ -45,15 +50,7 @@ def math_env():
4550
@pytest.fixture(scope="module")
4651
def multichoice_env(request):
4752
"""Create a MathEnvironment actor with a parametrized verifier type for testing."""
48-
verifier_type = request.param
49-
env = MathEnvironment.options(
50-
runtime_env={
51-
"py_executable": get_actor_python_env(
52-
"nemo_rl.environments.math_environment.MathEnvironment"
53-
),
54-
"env_vars": dict(os.environ),
55-
}
56-
).remote({"num_workers": 2, "verifier_type": verifier_type})
53+
env = create_env("math", {"num_workers": 2, "verifier_type": request.param})
5754
yield env
5855
# Clean up the actor and wait for it to be killed
5956
env.shutdown.remote()
@@ -62,6 +59,11 @@ def multichoice_env(request):
6259
time.sleep(0.1)
6360

6461

62+
# ============================================================================
63+
# Data fixtures
64+
# ============================================================================
65+
66+
6567
@pytest.fixture
6668
def basic_test_data():
6769
"""Common test data for basic math problems."""
@@ -88,6 +90,41 @@ def basic_test_data():
8890
}
8991

9092

93+
@pytest.fixture
94+
def multi_reward_test_data():
95+
"""Common test data for basic math problems with multiple rewards."""
96+
return {
97+
"message_log_batch": [
98+
[
99+
{"role": "user", "content": "What is 2 + 2?"},
100+
{
101+
"role": "assistant",
102+
"content": "<think>2 + 2 = 4</think>\n<answer>4</answer>",
103+
},
104+
],
105+
[
106+
{"role": "user", "content": "What is 3 * 4?"},
107+
{
108+
"role": "assistant",
109+
"content": "<think>3 * 4 = 12</think>\n<answer>12.5</answer>",
110+
},
111+
],
112+
[
113+
{"role": "user", "content": "What is 10 - 5?"},
114+
{
115+
"role": "assistant",
116+
"content": "<think>10 - 5 = 5\n<answer>5</answer>",
117+
},
118+
],
119+
],
120+
"metadata": [
121+
{"ground_truth": "4"},
122+
{"ground_truth": "12"},
123+
{"ground_truth": "5"},
124+
],
125+
}
126+
127+
91128
@pytest.fixture
92129
def multichoice_test_data(request):
93130
"""Common test data for basic multichoice problems."""
@@ -170,6 +207,11 @@ def multiple_assistant_test_data():
170207
}
171208

172209

210+
# ============================================================================
211+
# Environment tests
212+
# ============================================================================
213+
214+
173215
def test_math_env_step_basic(math_env, basic_test_data):
174216
"""Test basic functionality of MathEnvironment step with simple messages."""
175217
result = ray.get(
@@ -204,6 +246,56 @@ def test_math_env_step_basic(math_env, basic_test_data):
204246
assert all(result.terminateds == 1.0), "All terminated flags should be 1.0"
205247

206248

249+
def test_multi_reward_env_step_basic(math_multi_reward_env, multi_reward_test_data):
250+
"""Test basic step returns correct per-sample rewards for all 3 reward types."""
251+
result = ray.get(
252+
math_multi_reward_env.step.remote(
253+
multi_reward_test_data["message_log_batch"],
254+
multi_reward_test_data["metadata"],
255+
)
256+
)
257+
258+
# Check observations (based on correctness reward, index 0)
259+
assert len(result.observations) == 3, (
260+
"Should return observations for all 3 messages"
261+
)
262+
assert all(obs["role"] == "environment" for obs in result.observations), (
263+
"All observations should be from environment"
264+
)
265+
266+
# Check observations for each data point
267+
assert result.observations[0]["content"] == "Environment: correct"
268+
assert result.observations[1]["content"] == "Environment: incorrect"
269+
assert result.observations[2]["content"] == "Environment: correct"
270+
271+
# Check metadata
272+
assert len(result.metadata) == 3, "Should return metadata for all 3 messages"
273+
assert result.metadata == multi_reward_test_data["metadata"], (
274+
"Metadata should be unchanged"
275+
)
276+
277+
# Check rewards: shape (batch_size=3, number_of_rewards=3)
278+
assert result.rewards.shape == (3, 3), "Rewards should be a tensor of shape (3, 3)"
279+
280+
# Check rewards for each data point
281+
# First data point: correctness reward 1.0, int reward 1.0, format reward 1.0
282+
assert (result.rewards[0] == 1.0).all(), "First reward should be 1.0"
283+
# Second data point: correctness reward 0.0, int reward 0.0, format reward 1.0
284+
assert result.rewards[1][0] == 0.0
285+
assert result.rewards[1][1] == 0.0
286+
assert result.rewards[1][2] == 1.0
287+
# Third data point: correctness reward 1.0, int reward 1.0, format reward 0.0
288+
assert result.rewards[2][0] == 1.0
289+
assert result.rewards[2][1] == 1.0
290+
assert result.rewards[2][2] == 0.0
291+
292+
# Check terminated flags
293+
assert result.terminateds.shape == (3,), (
294+
"Terminated flags should be a tensor of shape (3,)"
295+
)
296+
assert all(result.terminateds == 1.0), "All terminated flags should be 1.0"
297+
298+
207299
@pytest.mark.parametrize(
208300
"multichoice_env, multichoice_test_data",
209301
[

0 commit comments

Comments
 (0)