Commit 76f0798

[LLM] Add Countdown numbers-game environment
Add CountdownEnv and CountdownRewardParser for the Countdown numbers game, a popular lightweight problem for GRPO training.

Key features:
- Procedural problem generation (no external dataset required)
- Validates arithmetic expressions: correct operators, each source number used at most once, evaluates to the target
- Same 0/0.1/1.0 reward convention as the GSM8K and MATH parsers
- Configurable problem difficulty (num_count, max_number, max_target)
- Includes unit tests and documentation

Made-with: Cursor
ghstack-source-id: 80b4e64
Pull-Request: #3545
1 parent 004af28 commit 76f0798
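
As a quick orientation before the diff, here is a minimal smoke test adapted from the docstring example the new file ships (an illustrative sketch: it assumes `transformers` is installed and the `Qwen/Qwen2.5-3B` tokenizer is reachable; any chat-capable tokenizer should do):

import transformers
from torchrl.envs.llm.datasets.countdown import CountdownEnv

tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
env = CountdownEnv(tokenizer=tokenizer, apply_template=True, seed=42)
reset = env.reset()
assert "history" in reset  # the generated problem arrives as chat history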

File tree

7 files changed
+591 -2 lines changed

docs/source/reference/llms_envs.rst

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ The environment layer orchestrates data loading, tool execution, reward computat
     :template: rl_template.rst
 
     ChatEnv
+    CountdownEnv
+    CountdownRewardParser
     DatasetChatEnv
     GSM8KEnv
     make_gsm8k_env

test/llm/test_llm_envs.py

Lines changed: 59 additions & 0 deletions
@@ -474,6 +474,65 @@ def test_no_answer(self):
         assert td["reward"] == 0.0
 
 
+class TestCountdownRewardParser:
+    """Unit tests for the Countdown reward parser (no model/dataset required)."""
+
+    def test_validate_expression_correct(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        assert CountdownRewardParser.validate_expression(
+            "(25 + 3) * 4", 112, [25, 3, 4]
+        )
+
+    def test_validate_expression_wrong_result(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        assert not CountdownRewardParser.validate_expression("25 + 3", 100, [25, 3, 4])
+
+    def test_validate_expression_reuses_number(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        assert not CountdownRewardParser.validate_expression("25 + 25", 50, [25, 3, 4])
+
+    def test_validate_expression_invalid_chars(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        assert not CountdownRewardParser.validate_expression("import os", 0, [1, 2])
+
+    def test_parse_ground_truth(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        target, numbers = CountdownRewardParser._parse_ground_truth(
+            "target=42, numbers=10,20,5,7"
+        )
+        assert target == 42
+        assert numbers == [10, 20, 5, 7]
+
+    def test_correct_answer_reward(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        parser = CountdownRewardParser()
+        td = parser._single_correctness_reward(28, [25, 3], "25 + 3", "thinking")
+        assert td["success"]
+        assert td["reward"] == 1.0
+
+    def test_wrong_answer_with_format(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        parser = CountdownRewardParser()
+        td = parser._single_correctness_reward(100, [25, 3], "25 + 3", "thinking")
+        assert not td["success"]
+        assert td["reward"] == 0.1
+
+    def test_no_answer(self):
+        from torchrl.envs.llm.reward.countdown import CountdownRewardParser
+
+        parser = CountdownRewardParser()
+        td = parser._single_correctness_reward(100, [25, 3], "", "")
+        assert not td["success"]
+        assert td["reward"] == 0.0
+
+
 @pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
 class TestIFEvalRewardAggregator:
     """Unit tests for the simplified IFEval reward aggregator."""

torchrl/envs/llm/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -5,6 +5,7 @@
 
 from .chat import ChatEnv, DatasetChatEnv
 from .datasets import (
+    CountdownEnv,
     GSM8KEnv,
     GSM8KPrepareQuestion,
     IFEvalData,
@@ -14,7 +15,13 @@
 )
 from .envs import LLMEnv, LLMHashingEnv
 from .libs import make_mlgym, MLGymWrapper
-from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer, MATHRewardParser
+from .reward import (
+    CountdownRewardParser,
+    GSM8KRewardParser,
+    IFEvalScoreData,
+    IfEvalScorer,
+    MATHRewardParser,
+)
 from .transforms import (
     AddThinkingPrompt,
     as_nested_tensor,
@@ -59,6 +66,8 @@
     "Tokenizer",
     "as_nested_tensor",
     "as_padded_tensor",
+    "CountdownEnv",
+    "CountdownRewardParser",
     "make_gsm8k_env",
     "make_mlgym",
     "MATHEnv",

torchrl/envs/llm/datasets/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -4,11 +4,13 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+from .countdown import CountdownEnv
 from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
 from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
 from .math import MATHEnv
 
 __all__ = [
+    "CountdownEnv",
     "make_gsm8k_env",
     "GSM8KPrepareQuestion",
    "GSM8KEnv",
torchrl/envs/llm/datasets/countdown.py

Lines changed: 187 additions & 0 deletions

@@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import random
from collections.abc import Callable
from typing import Any, Literal, TYPE_CHECKING

import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader, IterableDataset
from torchrl.envs import StepCounter
from torchrl.envs.llm.chat import DatasetChatEnv
from torchrl.envs.llm.reward.countdown import CountdownRewardParser

if TYPE_CHECKING:
    import transformers


class _CountdownProblemGenerator(IterableDataset):
    """Infinite procedural generator for Countdown problems.

    Each problem picks ``num_count`` numbers from [1, max_number] and
    generates a target that is reachable from those numbers using
    ``+``, ``-``, ``*``, ``/``.
    """

    def __init__(
        self,
        num_count: int = 4,
        max_number: int = 100,
        max_target: int = 1000,
        seed: int | None = None,
    ):
        self.num_count = num_count
        self.max_number = max_number
        self.max_target = max_target
        self.rng = random.Random(seed)

    def __iter__(self):
        return self

    def __next__(self) -> dict[str, Any]:
        numbers = [self.rng.randint(1, self.max_number) for _ in range(self.num_count)]
        target = self._make_target(numbers)
        query = (
            f"Using the numbers {numbers}, create an arithmetic expression that "
            f"equals {target}. You may use each number at most once. "
            f"Only use +, -, *, / and parentheses."
        )
        answer = f"target={target}, numbers={','.join(str(n) for n in numbers)}"
        return {"query": query, "answer": answer}

    def _make_target(self, numbers: list[int]) -> int:
        """Generate a reachable target by randomly combining numbers."""
        ops = [
            lambda a, b: a + b,
            lambda a, b: a - b,
            lambda a, b: a * b,
        ]
        pool = list(numbers)
        self.rng.shuffle(pool)
        result = pool[0]
        for n in pool[1:]:
            op = self.rng.choice(ops)
            result = op(result, n)
        result = abs(result)
        if result == 0:
            result = sum(numbers)
        if result > self.max_target:
            result = sum(numbers)
        return result


def _collate_fn(batch):
    return torch.stack([TensorDict.from_dict(b) for b in batch])


class CountdownEnv(DatasetChatEnv):
    """Countdown numbers-game environment for LLM post-training.

    Given a set of source numbers and a target, the model must construct an
    arithmetic expression that evaluates to the target using each source number
    at most once.

    Problems are generated procedurally (no external dataset required), making
    this environment ideal for quick experimentation and debugging of RL
    training loops.

    Keyword Args:
        num_count (int): How many source numbers per problem. Defaults to ``4``.
        max_number (int): Maximum value for each source number. Defaults to ``100``.
        max_target (int): Ceiling for the generated target. Defaults to ``1000``.
        shuffle (bool): Ignored (procedural generation is always random).
        num_envs (int): Number of parallel environments. Defaults to ``1``.
        repeats (int | None): Repeats per sample for MC estimation.
        batch_size_dl (int): Dataloader batch size. Defaults to ``1``.
        seed (int | None): Random seed for reproducibility.
        group_repeats (bool): Group repeated samples. Defaults to ``False``.
        tokenizer: Tokenizer for text processing.
        device: Device for computation.
        template_kwargs: Extra kwargs for ``apply_chat_template``.
        apply_template (bool): Apply chat template. Defaults to ``False``.
        compute_reward (bool): Compute rewards. Defaults to ``True``.
        collate_fn: Custom collate function.
        max_steps (int): Max steps per episode. Defaults to ``1``.
        input_mode: ``"history"``, ``"text"``, or ``"tokens"``.

    Examples:
        >>> import transformers
        >>> from torchrl.envs.llm.datasets.countdown import CountdownEnv
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        >>> env = CountdownEnv(tokenizer=tokenizer, apply_template=True, seed=42)
        >>> r = env.reset()
        >>> assert "history" in r

    """

    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user gives a set of "
        "numbers and a target. The Assistant must find an arithmetic expression "
        "using each given number at most once that equals the target.\n"
        "The reasoning process and answer are enclosed within <think></think> "
        "and <answer></answer> tags, respectively.\n"
        "The answer should contain ONLY the arithmetic expression (e.g. "
        "(25 + 3) * 4)."
    )

    def __init__(
        self,
        *,
        num_count: int = 4,
        max_number: int = 100,
        max_target: int = 1000,
        shuffle: bool = True,
        num_envs: int = 1,
        repeats: int | None = None,
        batch_size_dl: int = 1,
        seed: int | None = None,
        group_repeats: bool = False,
        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
        device: torch.device | None = None,
        template_kwargs: dict[str, Any] | None = None,
        apply_template: bool | None = False,
        compute_reward: bool = True,
        collate_fn: Callable | None = None,
        max_steps: int = 1,
        input_mode: Literal["history", "text", "tokens"] = "history",
    ):
        if collate_fn is None:
            collate_fn = _collate_fn

        self._num_count = num_count
        self._max_number = max_number
        self._max_target = max_target
        self._seed = seed

        batch_size = (num_envs,)
        dataloader = DataLoader(  # noqa: TOR401
            _CountdownProblemGenerator(
                num_count=num_count,
                max_number=max_number,
                max_target=max_target,
                seed=seed,
            ),
            batch_size=batch_size_dl,
            collate_fn=collate_fn,
        )

        self._from_dataloader(
            self,
            dataloader=dataloader,
            repeats=repeats,
            device=device,
            group_repeats=group_repeats,
            batch_size=batch_size,
            tokenizer=tokenizer,
            template_kwargs=template_kwargs,
            input_mode=input_mode,
        )

        if max_steps:
            self.append_transform(StepCounter(max_steps=max_steps))
        if compute_reward:
            self.append_transform(CountdownRewardParser(tokenizer=tokenizer))
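
`_CountdownProblemGenerator` is a private helper, but iterating it directly is a convenient way to eyeball the procedural data without a tokenizer or model. A minimal sketch; the import path assumes this file lands at `torchrl/envs/llm/datasets/countdown.py`, as the package imports above suggest:

from torchrl.envs.llm.datasets.countdown import _CountdownProblemGenerator

gen = _CountdownProblemGenerator(num_count=4, max_number=100, max_target=1000, seed=0)
for _ in range(3):
    problem = next(gen)
    print(problem["query"])   # natural-language prompt shown to the model
    print(problem["answer"])  # ground truth, e.g. "target=42, numbers=10,20,5,7"

By construction every emitted target is reachable from the source numbers (`_make_target` only combines them with +, -, *), so generated problems are always solvable.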

torchrl/envs/llm/reward/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -4,8 +4,15 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
+from .countdown import CountdownRewardParser
 from .gsm8k import GSM8KRewardParser
 from .ifeval import IFEvalScoreData, IfEvalScorer
 from .math import MATHRewardParser
 
-__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData", "MATHRewardParser"]
+__all__ = [
+    "CountdownRewardParser",
+    "IfEvalScorer",
+    "GSM8KRewardParser",
+    "IFEvalScoreData",
+    "MATHRewardParser",
+]
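
`CountdownEnv` appends this parser automatically when `compute_reward=True` (see its `__init__` above); wiring it up by hand to a compatible environment is the same one-liner, sketched here with `env` and `tokenizer` standing in for the objects from the earlier example:

from torchrl.envs.llm.reward import CountdownRewardParser

env.append_transform(CountdownRewardParser(tokenizer=tokenizer))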
