@@ -266,10 +266,92 @@ def test_gsm8kenv(self, ray_backend, device, ray_backend_fixture):
266266 r ["history" ].full = history_full
267267 s = env .step (r )
268268 assert s .device == device
269- assert s ["next" , "reward" ] >= 10
269+ assert s ["next" , "reward" ] > 0
270270 assert s ["next" , "done" ].all ()
271271
272272
class TestGSM8KRewardParser:
    """Model-free, dataset-free unit tests for ``GSM8KRewardParser``.

    The parser class is imported lazily inside each test (through
    :meth:`_reward_parser_cls`) rather than at module import time, so the
    import only happens when one of these tests actually runs.
    """

    @staticmethod
    def _reward_parser_cls():
        """Import ``GSM8KRewardParser`` on demand and return the class."""
        from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser

        return GSM8KRewardParser

    def test_extract_tags(self):
        """Both the ``<think>`` and ``<answer>`` payloads are returned."""
        completion = "<think>some reasoning</think> <answer>42</answer>"
        reasoning, answer_text = self._reward_parser_cls().extract_tags(completion)
        assert reasoning == "some reasoning"
        assert answer_text == "42"

    def test_extract_tags_malformed(self):
        """The answer is still extracted when the reasoning holds ``<``, ``>`` and ``&``."""
        completion = (
            "<think>reasoning with <special> chars & stuff</think> <answer>5</answer>"
        )
        # Only the answer is checked; the unpacking itself still requires
        # the call to return exactly two values.
        _, answer_text = self._reward_parser_cls().extract_tags(completion)
        assert answer_text == "5"

    def test_extract_tags_missing(self):
        """Text without any tags yields two empty strings."""
        reasoning, answer_text = self._reward_parser_cls().extract_tags(
            "no tags here at all"
        )
        assert reasoning == ""
        assert answer_text == ""

    def test_normalize_answer(self):
        """``normalize_answer`` maps each raw string to its expected canonical form."""
        normalize = self._reward_parser_cls().normalize_answer
        # (raw input, expected normalized output), checked in this order.
        expectations = (
            ("1,234", "1234"),
            ("$120", "120"),
            ("120.0", "120"),
            ("120.00", "120"),
            (" 42 ", "42"),
            ("3.14", "3.14"),
            ("100%", "100"),
        )
        for raw, expected in expectations:
            assert normalize(raw) == expected

    def test_correct_answer_reward(self):
        """Identical first two arguments give success and a reward of 1.0."""
        scorer = self._reward_parser_cls()()
        outcome = scorer._single_correctness_reward("42", "42", "some reasoning")
        assert outcome["success"]
        assert outcome["reward"] == 1.0

    def test_wrong_answer_with_format(self):
        """A non-matching, non-empty answer with reasoning text gives reward 0.1."""
        scorer = self._reward_parser_cls()()
        outcome = scorer._single_correctness_reward("42", "99", "some reasoning")
        assert not outcome["success"]
        assert outcome["reward"] == 0.1
        assert outcome["reward_answer"] == 1.0

    def test_no_answer(self):
        """Empty second and third arguments give no success and zero rewards."""
        scorer = self._reward_parser_cls()()
        outcome = scorer._single_correctness_reward("42", "", "")
        assert not outcome["success"]
        assert outcome["reward"] == 0.0
        assert outcome["reward_answer"] == 0.0

    def test_normalized_match(self):
        """``"1234"`` and ``"1,234"`` are scored as a successful match."""
        scorer = self._reward_parser_cls()()
        outcome = scorer._single_correctness_reward("1234", "1,234", "thinking")
        assert outcome["success"]
        assert outcome["reward"] == 1.0

    def test_custom_reward_values(self):
        """Constructor-supplied ``format_reward`` / ``correct_reward`` show up in the reward."""
        scorer = self._reward_parser_cls()(format_reward=0.5, correct_reward=2.0)
        matched = scorer._single_correctness_reward("42", "42", "cot")
        assert matched["reward"] == 2.0
        mismatched = scorer._single_correctness_reward("42", "99", "cot")
        assert mismatched["reward"] == 0.5


354+
273355@pytest .mark .skipif (not _has_ifeval , reason = "requires IFEval libs" )
274356class TestIFEvalEnv :
275357 def test_ifeval (self ):
0 commit comments