1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14- import os
1514import time
1615
1716import pytest
1817import ray
1918
20- from nemo_rl .distributed .ray_actor_environment_registry import (
21- get_actor_python_env ,
22- )
23- from nemo_rl .environments .math_environment import MathEnvironment
19+ from nemo_rl .environments .utils import create_env
20+
21+ # ============================================================================
22+ # Environment fixtures
23+ # ============================================================================
2424
2525
@pytest.fixture(scope="module")
def math_env():
    """Create a MathEnvironment actor for testing."""
    actor = create_env("math", {"num_workers": 2})
    yield actor
    # Tear down: request a graceful shutdown, then force-kill the actor.
    actor.shutdown.remote()
    ray.kill(actor)
    # Short pause so Ray can finish cleanup before the next fixture runs.
    time.sleep(0.1)
36+
37+
38+ @pytest .fixture (scope = "module" )
39+ def math_multi_reward_env ():
40+ """Create a MathMultiRewardEnvironment actor for testing."""
41+ env = create_env ("math_multi_reward" , {"num_workers" : 2 })
3742 yield env
3843 # Clean up the actor and wait for it to be killed
3944 env .shutdown .remote ()
@@ -45,15 +50,7 @@ def math_env():
4550@pytest .fixture (scope = "module" )
4651def multichoice_env (request ):
4752 """Create a MathEnvironment actor for testing."""
48- verifier_type = request .param
49- env = MathEnvironment .options (
50- runtime_env = {
51- "py_executable" : get_actor_python_env (
52- "nemo_rl.environments.math_environment.MathEnvironment"
53- ),
54- "env_vars" : dict (os .environ ),
55- }
56- ).remote ({"num_workers" : 2 , "verifier_type" : verifier_type })
53+ env = create_env ("math" , {"num_workers" : 2 , "verifier_type" : request .param })
5754 yield env
5855 # Clean up the actor and wait for it to be killed
5956 env .shutdown .remote ()
@@ -62,6 +59,11 @@ def multichoice_env(request):
6259 time .sleep (0.1 )
6360
6461
62+ # ============================================================================
63+ # Data fixtures
64+ # ============================================================================
65+
66+
6567@pytest .fixture
6668def basic_test_data ():
6769 """Common test data for basic math problems."""
@@ -88,6 +90,41 @@ def basic_test_data():
8890 }
8991
9092
@pytest.fixture
def multi_reward_test_data():
    """Common test data for basic math problems with multiple rewards."""
    prompts = [
        "What is 2 + 2?",
        "What is 3 * 4?",
        "What is 10 - 5?",
    ]
    # Case 1: correct + well-formed; case 2: wrong answer, well-formed;
    # case 3: correct answer but the <think> tag is never closed.
    responses = [
        "<think>2 + 2 = 4</think>\n<answer>4</answer>",
        "<think>3 * 4 = 12</think>\n<answer>12.5</answer>",
        "<think>10 - 5 = 5\n<answer>5</answer>",
    ]
    ground_truths = ["4", "12", "5"]
    return {
        "message_log_batch": [
            [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
            for question, answer in zip(prompts, responses)
        ],
        "metadata": [{"ground_truth": gt} for gt in ground_truths],
    }
126+
127+
91128@pytest .fixture
92129def multichoice_test_data (request ):
93130 """Common test data for basic multichoice problems."""
@@ -170,6 +207,11 @@ def multiple_assistant_test_data():
170207 }
171208
172209
210+ # ============================================================================
211+ # Environment tests
212+ # ============================================================================
213+
214+
173215def test_math_env_step_basic (math_env , basic_test_data ):
174216 """Test basic functionality of MathEnvironment step with simple messages."""
175217 result = ray .get (
@@ -204,6 +246,56 @@ def test_math_env_step_basic(math_env, basic_test_data):
204246 assert all (result .terminateds == 1.0 ), "All terminated flags should be 1.0"
205247
206248
def test_multi_reward_env_step_basic(math_multi_reward_env, multi_reward_test_data):
    """Test basic step: correct answer + valid format -> all 3 rewards 1.0."""
    batch = multi_reward_test_data["message_log_batch"]
    meta = multi_reward_test_data["metadata"]
    result = ray.get(math_multi_reward_env.step.remote(batch, meta))

    # Observations reflect the correctness reward (index 0).
    assert len(result.observations) == 3, (
        "Should return observations for all 3 messages"
    )
    assert all(obs["role"] == "environment" for obs in result.observations), (
        "All observations should be from environment"
    )
    expected_contents = [
        "Environment: correct",
        "Environment: incorrect",
        "Environment: correct",
    ]
    for obs, expected in zip(result.observations, expected_contents):
        assert obs["content"] == expected

    # Metadata passes through the step untouched.
    assert len(result.metadata) == 3, "Should return metadata for all 3 messages"
    assert result.metadata == meta, (
        "Metadata should be unchanged"
    )

    # Rewards have shape (batch_size=3, number_of_rewards=3), ordered as
    # (correctness, int, format).
    assert result.rewards.shape == (3, 3), "Rewards should be a tensor of shape (3, 3)"
    # Row 0: correct answer, int answer, valid format -> all ones.
    assert (result.rewards[0] == 1.0).all(), "First reward should be 1.0"
    # Row 1: wrong non-int answer but valid format -> only format reward.
    assert result.rewards[1].tolist() == [0.0, 0.0, 1.0]
    # Row 2: correct int answer but broken format -> format reward is zero.
    assert result.rewards[2].tolist() == [1.0, 1.0, 0.0]

    # Every episode terminates after a single step.
    assert result.terminateds.shape == (3,), (
        "Terminated flags should be a tensor of shape (3,)"
    )
    assert all(result.terminateds == 1.0), "All terminated flags should be 1.0"
298+
207299@pytest .mark .parametrize (
208300 "multichoice_env, multichoice_test_data" ,
209301 [
0 commit comments