Skip to content

Commit 457915d

Browse files
committed: Update
[ghstack-poisoned]
1 parent: 3d14f3a — this commit: 457915d

File tree

4 files changed

+157
-111
lines changed

4 files changed

+157
-111
lines changed

sota-implementations/grpo/grpo_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -649,8 +649,8 @@ def make_env(cfg: DictConfig, single_env: bool = False):
649649
# Setup environment
650650
max_steps = cfg.env.max_steps if cfg.env.reasoning else 1
651651
if cfg.env.dataset == "gsm8k":
652-
# Reward scale is 0.0 to 100
653-
reward_threshold = 20
652+
# Reward scale is 0.0 to 1.0
653+
reward_threshold = 0.1
654654
env = GSM8KEnv(
655655
repeats=cfg.env.repeats,
656656
tokenizer=train_tokenizer,
@@ -659,9 +659,9 @@ def make_env(cfg: DictConfig, single_env: bool = False):
659659
device=torch.device("cpu"),
660660
ray_backend=True,
661661
)
662-
elif cfg.env.dataset == "ifeval": # ifeval
663-
# Reward scale is 0.0 to 2.2
664-
reward_threshold = 1.0
662+
elif cfg.env.dataset == "ifeval":
663+
# Reward scale is 0.0 to ~1.15
664+
reward_threshold = 0.5
665665
env = IFEvalEnv(
666666
repeats=cfg.env.repeats,
667667
tokenizer=train_tokenizer,

test/llm/test_llm_envs.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,92 @@ def test_gsm8kenv(self, ray_backend, device, ray_backend_fixture):
266266
r["history"].full = history_full
267267
s = env.step(r)
268268
assert s.device == device
269-
assert s["next", "reward"] >= 10
269+
assert s["next", "reward"] > 0
270270
assert s["next", "done"].all()
271271

272272

273+
class TestGSM8KRewardParser:
274+
"""Unit tests for the GSM8K reward parser (no model/dataset required)."""
275+
276+
def test_extract_tags(self):
277+
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
278+
279+
think, answer = GSM8KRewardParser.extract_tags(
280+
"<think>some reasoning</think> <answer>42</answer>"
281+
)
282+
assert think == "some reasoning"
283+
assert answer == "42"
284+
285+
def test_extract_tags_malformed(self):
286+
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
287+
288+
think, answer = GSM8KRewardParser.extract_tags(
289+
"<think>reasoning with <special> chars & stuff</think> <answer>5</answer>"
290+
)
291+
assert answer == "5"
292+
293+
def test_extract_tags_missing(self):
294+
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
295+
296+
think, answer = GSM8KRewardParser.extract_tags("no tags here at all")
297+
assert think == ""
298+
assert answer == ""
299+
300+
def test_normalize_answer(self):
301+
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
302+
303+
assert GSM8KRewardParser.normalize_answer("1,234") == "1234"
304+
assert GSM8KRewardParser.normalize_answer("$120") == "120"
305+
assert GSM8KRewardParser.normalize_answer("120.0") == "120"
306+
assert GSM8KRewardParser.normalize_answer("120.00") == "120"
307+
assert GSM8KRewardParser.normalize_answer(" 42 ") == "42"
308+
assert GSM8KRewardParser.normalize_answer("3.14") == "3.14"
309+
assert GSM8KRewardParser.normalize_answer("100%") == "100"
310+
311+
def test_correct_answer_reward(self):
312+
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
313+
314+
parser = GSM8KRewardParser()
315+
td = parser._single_correctness_reward("42", "42", "some reasoning")
316+
assert td["success"]
317+
assert td["reward"] == 1.0
318+
319+
def test_wrong_answer_with_format(self):
320+
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
321+
322+
parser = GSM8KRewardParser()
323+
td = parser._single_correctness_reward("42", "99", "some reasoning")
324+
assert not td["success"]
325+
assert td["reward"] == 0.1
326+
assert td["reward_answer"] == 1.0
327+
328+
def test_no_answer(self):
329+
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
330+
331+
parser = GSM8KRewardParser()
332+
td = parser._single_correctness_reward("42", "", "")
333+
assert not td["success"]
334+
assert td["reward"] == 0.0
335+
assert td["reward_answer"] == 0.0
336+
337+
def test_normalized_match(self):
338+
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
339+
340+
parser = GSM8KRewardParser()
341+
td = parser._single_correctness_reward("1234", "1,234", "thinking")
342+
assert td["success"]
343+
assert td["reward"] == 1.0
344+
345+
def test_custom_reward_values(self):
346+
from torchrl.envs.llm.reward.gsm8k import GSM8KRewardParser
347+
348+
parser = GSM8KRewardParser(format_reward=0.5, correct_reward=2.0)
349+
td_correct = parser._single_correctness_reward("42", "42", "cot")
350+
assert td_correct["reward"] == 2.0
351+
td_format = parser._single_correctness_reward("42", "99", "cot")
352+
assert td_format["reward"] == 0.5
353+
354+
273355
@pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
274356
class TestIFEvalEnv:
275357
def test_ifeval(self):

torchrl/envs/llm/datasets/gsm8k.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,6 @@ class GSM8KEnv(DatasetChatEnv):
259259
is_shared=False),
260260
reward: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
261261
reward_answer: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
262-
reward_contained: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
263262
reward_right: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
264263
reward_think: Tensor(shape=torch.Size([1, 1, 1]), device=cpu, dtype=torch.float32, is_shared=False),
265264
step_count: Tensor(shape=torch.Size([1, 1]), device=cpu, dtype=torch.int64, is_shared=False),
@@ -293,7 +292,7 @@ class GSM8KEnv(DatasetChatEnv):
293292
device=None,
294293
is_shared=False,
295294
stack_dim=0)
296-
>>> assert s["next", "reward"] >= 10
295+
>>> assert s["next", "reward"] > 0
297296
>>> assert s["next", "done"].all()
298297
299298
"""

0 commit comments

Comments (0)