Skip to content

Commit 9320fdd

Browse files
committed
Update
[ghstack-poisoned]
1 parent 9b42944 commit 9320fdd

File tree

7 files changed

+525
-2
lines changed

7 files changed

+525
-2
lines changed

docs/source/reference/llms_envs.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ The environment layer orchestrates data loading, tool execution, reward computat
2020
IFEvalEnv
2121
IfEvalScorer
2222
IFEvalScoreData
23+
MATHEnv
24+
MATHRewardParser
2325
LLMEnv
2426
LLMHashingEnv
2527
make_mlgym

test/llm/test_llm_envs.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,60 @@ def test_ifeval(self):
420420
# env.check_env_specs()
421421

422422

423+
class TestMATHRewardParser:
    """Unit tests for the MATH reward parser; runs without any model or dataset."""

    def test_extract_boxed_simple(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        extracted = MATHRewardParser.extract_boxed(r"The answer is $\boxed{42}$.")
        assert extracted == "42"

    def test_extract_boxed_nested(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        # Braces inside \boxed{} must be matched, not cut at the first '}'.
        extracted = MATHRewardParser.extract_boxed(r"$\boxed{\frac{1}{2}}$")
        assert extracted == r"\frac{1}{2}"

    def test_extract_boxed_no_boxed(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        # Without a \boxed{} marker the input is returned unchanged.
        assert MATHRewardParser.extract_boxed("no boxed here") == "no boxed here"

    def test_extract_tags(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        reasoning, final = MATHRewardParser.extract_tags(
            r"<think>reasoning</think> <answer>\frac{1}{2}</answer>"
        )
        assert reasoning == "reasoning"
        assert final == r"\frac{1}{2}"

    def test_correct_answer(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        result = MATHRewardParser()._single_correctness_reward("42", "42", "reasoning")
        assert result["success"]
        assert result["reward"] == 1.0

    def test_wrong_answer_with_format(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        # Wrong answer but well-formed response earns the small format reward.
        result = MATHRewardParser()._single_correctness_reward("42", "99", "reasoning")
        assert not result["success"]
        assert result["reward"] == 0.1

    def test_no_answer(self):
        from torchrl.envs.llm.reward.math import MATHRewardParser

        # Empty answer and empty reasoning: no reward at all.
        result = MATHRewardParser()._single_correctness_reward("42", "", "")
        assert not result["success"]
        assert result["reward"] == 0.0
423477
@pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
424478
class TestIFEvalRewardAggregator:
425479
"""Unit tests for the simplified IFEval reward aggregator."""

torchrl/envs/llm/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
IFEvalData,
1111
IFEvalEnv,
1212
make_gsm8k_env,
13+
MATHEnv,
1314
)
1415
from .envs import LLMEnv, LLMHashingEnv
1516
from .libs import make_mlgym, MLGymWrapper
16-
from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer
17+
from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer, MATHRewardParser
1718
from .transforms import (
1819
AddThinkingPrompt,
1920
as_nested_tensor,
@@ -60,4 +61,6 @@
6061
"as_padded_tensor",
6162
"make_gsm8k_env",
6263
"make_mlgym",
64+
"MATHEnv",
65+
"MATHRewardParser",
6366
]

torchrl/envs/llm/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
88
from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
9+
from .math import MATHEnv
910

1011
__all__ = [
1112
"make_gsm8k_env",
@@ -14,4 +15,5 @@
1415
"IFEvalEnv",
1516
"IFEvalData",
1617
"IfEvalScorer",
18+
"MATHEnv",
1719
]

torchrl/envs/llm/datasets/math.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
from __future__ import annotations
6+
7+
from collections.abc import Callable
8+
from typing import Any, Literal, TYPE_CHECKING
9+
10+
import torch
11+
from tensordict import TensorDict
12+
from torchrl.envs import StepCounter
13+
from torchrl.envs.llm.chat import DatasetChatEnv
14+
from torchrl.envs.llm.reward.math import MATHRewardParser
15+
16+
if TYPE_CHECKING:
17+
import transformers
18+
19+
20+
def _collate_fn(batch):
    """Stack raw dataset rows and map MATH column names onto the env schema.

    Each row is converted to a :class:`~tensordict.TensorDict`, the rows are
    stacked along a new batch dimension, and the dataset-specific columns
    ``problem``/``solution`` are renamed to the ``query``/``answer`` keys
    that the chat environment expects.
    """
    stacked = torch.stack([TensorDict.from_dict(row) for row in batch])
    stacked.rename_key_("problem", "query")
    stacked.rename_key_("solution", "answer")
    return stacked
26+
27+
class MATHEnv(DatasetChatEnv):
    r"""MATH (competition mathematics) dataset environment.

    Uses the ``DigitalLearningGmbH/MATH-lighteval`` dataset on Hugging Face
    (a drop-in replacement for the original ``hendrycks/competition_math``).

    Answers are in LaTeX ``\boxed{}`` format. When ``math-verify`` is
    installed the reward parser uses symbolic equivalence checking; otherwise
    it falls back to normalised string comparison.

    Keyword Args:
        dataset (str, optional): HuggingFace dataset name.
            Defaults to ``"DigitalLearningGmbH/MATH-lighteval"``.
        shuffle (bool, optional): Shuffle the dataset. Defaults to ``True``.
        num_envs (int, optional): Number of parallel envs. Defaults to ``1``.
        repeats (int | None, optional): Repeats per sample for MC estimation.
        batch_size_dl (int, optional): Dataloader batch size. Defaults to ``1``.
        seed (int | None, optional): Random seed.
        group_repeats (bool, optional): Group repeated samples. Defaults to ``False``.
        tokenizer: Tokenizer for text processing.
        device: Device for computation.
        template_kwargs: Extra kwargs for ``apply_chat_template``.
        apply_template (bool): Apply chat template. Defaults to ``False``.
        compute_reward (bool): Compute rewards. Defaults to ``True``.
        collate_fn: Custom collate function.
        max_steps (int): Max steps per episode. Defaults to ``1``.
        input_mode: ``"history"``, ``"text"``, or ``"tokens"``.
        ray_backend (bool): Use Ray backend for data loading.
        dataloader_actor_name (str): Ray actor name for data loading.

    Examples:
        >>> import transformers
        >>> from torchrl.envs.llm.datasets.math import MATHEnv
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        >>> env = MATHEnv(tokenizer=tokenizer, apply_template=True)
        >>> r = env.reset()
        >>> assert "history" in r

    """

    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user asks a math problem, "
        "and the Assistant solves it.\n"
        "The assistant first thinks about the reasoning process in the mind and "
        "then provides the user with the answer.\n"
        "The reasoning process and answer are enclosed within <think></think> and "
        "<answer></answer> tags, respectively, i.e.,\n"
        "<think>reasoning process here</think> <answer>answer here</answer>.\n"
        "The answer should be a mathematical expression (use LaTeX if needed)."
    )

    def __init__(
        self,
        *,
        dataset: str = "DigitalLearningGmbH/MATH-lighteval",
        shuffle: bool = True,
        num_envs: int = 1,
        repeats: int | None = None,
        batch_size_dl: int = 1,
        seed: int | None = None,
        group_repeats: bool = False,
        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
        device: torch.device | None = None,
        template_kwargs: dict[str, Any] | None = None,
        apply_template: bool | None = False,
        compute_reward: bool = True,
        collate_fn: Callable | None = None,
        max_steps: int = 1,
        input_mode: Literal["history", "text", "tokens"] = "history",
        ray_backend: bool = False,
        dataloader_actor_name: str | None = None,
    ):
        # Ray actors need a stable name; fall back to a module default.
        if ray_backend and dataloader_actor_name is None:
            dataloader_actor_name = "math_dataloader"
        # The default collation renames MATH columns to the chat-env schema.
        effective_collate = _collate_fn if collate_fn is None else collate_fn
        super().__init__(
            dataset=dataset,
            shuffle=shuffle,
            num_envs=num_envs,
            repeats=repeats,
            batch_size_dl=batch_size_dl,
            seed=seed,
            group_repeats=group_repeats,
            tokenizer=tokenizer,
            device=device,
            template_kwargs=template_kwargs,
            apply_template=apply_template,
            collate_fn=effective_collate,
            input_mode=input_mode,
            ray_backend=ray_backend,
            dataloader_actor_name=dataloader_actor_name,
        )
        # Episode truncation and reward computation are appended as transforms
        # rather than passed to the base class.
        if max_steps:
            self.append_transform(StepCounter(max_steps=max_steps))
        if compute_reward:
            self.append_transform(MATHRewardParser(tokenizer=tokenizer))

torchrl/envs/llm/reward/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@
66

77
from .gsm8k import GSM8KRewardParser
88
from .ifeval import IFEvalScoreData, IfEvalScorer
9+
from .math import MATHRewardParser
910

10-
__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData"]
11+
__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData", "MATHRewardParser"]

0 commit comments

Comments
 (0)