Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/reference/llms_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ The environment layer orchestrates data loading, tool execution, reward computat
:template: rl_template.rst

ChatEnv
CountdownEnv
CountdownRewardParser
DatasetChatEnv
GSM8KEnv
make_gsm8k_env
Expand Down
59 changes: 59 additions & 0 deletions test/llm/test_llm_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,65 @@ def test_no_answer(self):
assert td["reward"] == 0.0


class TestCountdownRewardParser:
    """Parser-level tests for the Countdown reward; neither a model nor a dataset is needed."""

    def test_validate_expression_correct(self):
        from torchrl.envs.llm.reward.countdown import CountdownRewardParser

        is_valid = CountdownRewardParser.validate_expression(
            "(25 + 3) * 4", 112, [25, 3, 4]
        )
        assert is_valid

    def test_validate_expression_wrong_result(self):
        from torchrl.envs.llm.reward.countdown import CountdownRewardParser

        is_valid = CountdownRewardParser.validate_expression("25 + 3", 100, [25, 3, 4])
        assert not is_valid

    def test_validate_expression_reuses_number(self):
        from torchrl.envs.llm.reward.countdown import CountdownRewardParser

        # 25 appears once in the pool, so using it twice must be rejected.
        is_valid = CountdownRewardParser.validate_expression("25 + 25", 50, [25, 3, 4])
        assert not is_valid

    def test_validate_expression_invalid_chars(self):
        from torchrl.envs.llm.reward.countdown import CountdownRewardParser

        # Arbitrary code must be rejected, never evaluated.
        is_valid = CountdownRewardParser.validate_expression("import os", 0, [1, 2])
        assert not is_valid

    def test_parse_ground_truth(self):
        from torchrl.envs.llm.reward.countdown import CountdownRewardParser

        parsed = CountdownRewardParser._parse_ground_truth(
            "target=42, numbers=10,20,5,7"
        )
        target, numbers = parsed
        assert target == 42
        assert numbers == [10, 20, 5, 7]

    def test_correct_answer_reward(self):
        from torchrl.envs.llm.reward.countdown import CountdownRewardParser

        td = CountdownRewardParser()._single_correctness_reward(
            28, [25, 3], "25 + 3", "thinking"
        )
        assert td["success"]
        assert td["reward"] == 1.0

    def test_wrong_answer_with_format(self):
        from torchrl.envs.llm.reward.countdown import CountdownRewardParser

        td = CountdownRewardParser()._single_correctness_reward(
            100, [25, 3], "25 + 3", "thinking"
        )
        assert not td["success"]
        assert td["reward"] == 0.1

    def test_no_answer(self):
        from torchrl.envs.llm.reward.countdown import CountdownRewardParser

        td = CountdownRewardParser()._single_correctness_reward(100, [25, 3], "", "")
        assert not td["success"]
        assert td["reward"] == 0.0


@pytest.mark.skipif(not _has_ifeval, reason="requires IFEval libs")
class TestIFEvalRewardAggregator:
"""Unit tests for the simplified IFEval reward aggregator."""
Expand Down
11 changes: 10 additions & 1 deletion torchrl/envs/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .chat import ChatEnv, DatasetChatEnv
from .datasets import (
CountdownEnv,
GSM8KEnv,
GSM8KPrepareQuestion,
IFEvalData,
Expand All @@ -14,7 +15,13 @@
)
from .envs import LLMEnv, LLMHashingEnv
from .libs import make_mlgym, MLGymWrapper
from .reward import GSM8KRewardParser, IFEvalScoreData, IfEvalScorer, MATHRewardParser
from .reward import (
CountdownRewardParser,
GSM8KRewardParser,
IFEvalScoreData,
IfEvalScorer,
MATHRewardParser,
)
from .transforms import (
AddThinkingPrompt,
as_nested_tensor,
Expand Down Expand Up @@ -59,6 +66,8 @@
"Tokenizer",
"as_nested_tensor",
"as_padded_tensor",
"CountdownEnv",
"CountdownRewardParser",
"make_gsm8k_env",
"make_mlgym",
"MATHEnv",
Expand Down
2 changes: 2 additions & 0 deletions torchrl/envs/llm/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from .countdown import CountdownEnv
from .gsm8k import GSM8KEnv, GSM8KPrepareQuestion, make_gsm8k_env
from .ifeval import IFEvalData, IFEvalEnv, IfEvalScorer
from .math import MATHEnv

__all__ = [
"CountdownEnv",
"make_gsm8k_env",
"GSM8KPrepareQuestion",
"GSM8KEnv",
Expand Down
187 changes: 187 additions & 0 deletions torchrl/envs/llm/datasets/countdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import random
from collections.abc import Callable
from typing import Any, Literal, TYPE_CHECKING

import torch
from tensordict import TensorDict
from torch.utils.data import DataLoader, IterableDataset
from torchrl.envs import StepCounter
from torchrl.envs.llm.chat import DatasetChatEnv
from torchrl.envs.llm.reward.countdown import CountdownRewardParser

if TYPE_CHECKING:
import transformers


class _CountdownProblemGenerator(IterableDataset):
"""Infinite procedural generator for Countdown problems.

Each problem picks ``num_count`` numbers from [1, max_number] and
generates a target that is reachable from those numbers using
``+``, ``-``, ``*``, ``/``.
"""

def __init__(
self,
num_count: int = 4,
max_number: int = 100,
max_target: int = 1000,
seed: int | None = None,
):
self.num_count = num_count
self.max_number = max_number
self.max_target = max_target
self.rng = random.Random(seed)

def __iter__(self):
return self

def __next__(self) -> dict[str, Any]:
numbers = [self.rng.randint(1, self.max_number) for _ in range(self.num_count)]
target = self._make_target(numbers)
query = (
f"Using the numbers {numbers}, create an arithmetic expression that "
f"equals {target}. You may use each number at most once. "
f"Only use +, -, *, / and parentheses."
)
answer = f"target={target}, numbers={','.join(str(n) for n in numbers)}"
return {"query": query, "answer": answer}

def _make_target(self, numbers: list[int]) -> int:
"""Generate a reachable target by randomly combining numbers."""
ops = [
lambda a, b: a + b,
lambda a, b: a - b,
lambda a, b: a * b,
]
pool = list(numbers)
self.rng.shuffle(pool)
result = pool[0]
for n in pool[1:]:
op = self.rng.choice(ops)
result = op(result, n)
result = abs(result)
if result == 0:
result = sum(numbers)
if result > self.max_target:
result = sum(numbers)
return result


def _collate_fn(batch):
    """Convert a list of per-sample dicts into one stacked ``TensorDict``."""
    sample_tds = [TensorDict.from_dict(sample) for sample in batch]
    return torch.stack(sample_tds)


class CountdownEnv(DatasetChatEnv):
    """Countdown numbers-game environment for LLM post-training.

    Given a set of source numbers and a target, the model must construct an
    arithmetic expression that evaluates to the target using each source number
    at most once.

    Problems are generated procedurally (no external dataset required), making
    this environment ideal for quick experimentation and debugging of RL
    training loops.

    Keyword Args:
        num_count (int): How many source numbers per problem. Defaults to ``4``.
        max_number (int): Maximum value for each source number. Defaults to ``100``.
        max_target (int): Ceiling for the generated target. Defaults to ``1000``.
        shuffle (bool): Ignored (procedural generation is always random).
        num_envs (int): Number of parallel environments. Defaults to ``1``.
        repeats (int | None): Repeats per sample for MC estimation.
        batch_size_dl (int): Dataloader batch size. Defaults to ``1``.
        seed (int | None): Random seed for reproducibility.
        group_repeats (bool): Group repeated samples. Defaults to ``False``.
        tokenizer: Tokenizer for text processing.
        device: Device for computation.
        template_kwargs: Extra kwargs for ``apply_chat_template``.
        apply_template (bool): Apply chat template. Defaults to ``False``.
        compute_reward (bool): Compute rewards. Defaults to ``True``.
        collate_fn: Custom collate function.
        max_steps (int): Max steps per episode. Defaults to ``1``.
        input_mode: ``"history"``, ``"text"``, or ``"tokens"``.

    Examples:
        >>> import transformers
        >>> from torchrl.envs.llm.datasets.countdown import CountdownEnv
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
        >>> env = CountdownEnv(tokenizer=tokenizer, apply_template=True, seed=42)
        >>> r = env.reset()
        >>> assert "history" in r

    """

    SYSTEM_PROMPT = (
        "A conversation between User and Assistant. The user gives a set of "
        "numbers and a target. The Assistant must find an arithmetic expression "
        "using each given number at most once that equals the target.\n"
        "The reasoning process and answer are enclosed within <think></think> "
        "and <answer></answer> tags, respectively.\n"
        "The answer should contain ONLY the arithmetic expression (e.g. "
        "(25 + 3) * 4)."
    )

    def __init__(
        self,
        *,
        num_count: int = 4,
        max_number: int = 100,
        max_target: int = 1000,
        shuffle: bool = True,
        num_envs: int = 1,
        repeats: int | None = None,
        batch_size_dl: int = 1,
        seed: int | None = None,
        group_repeats: bool = False,
        tokenizer: transformers.AutoTokenizer | None = None,  # noqa
        device: torch.device | None = None,
        template_kwargs: dict[str, Any] | None = None,
        apply_template: bool | None = False,
        compute_reward: bool = True,
        collate_fn: Callable | None = None,
        max_steps: int = 1,
        input_mode: Literal["history", "text", "tokens"] = "history",
    ):
        # NOTE: ``shuffle`` and ``apply_template`` are accepted for API parity
        # with sibling dataset envs but are not used below; every sample is
        # freshly generated, so there is nothing to shuffle.
        if collate_fn is None:
            collate_fn = _collate_fn

        # Keep the generation parameters for introspection; the dataset below
        # receives its own copies.
        self._num_count = num_count
        self._max_number = max_number
        self._max_target = max_target
        self._seed = seed

        # Environment batch dimension: one slot per parallel environment.
        batch_size = (num_envs,)
        # The dataloader draws from an infinite procedural stream; it never
        # exhausts, so no shuffling/epoch handling is needed.
        dataloader = DataLoader(  # noqa: TOR401
            _CountdownProblemGenerator(
                num_count=num_count,
                max_number=max_number,
                max_target=max_target,
                seed=seed,
            ),
            batch_size=batch_size_dl,
            collate_fn=collate_fn,
        )

        # NOTE(review): ``_from_dataloader`` appears to be the DatasetChatEnv
        # initializer helper, invoked with an explicit ``self`` in place of a
        # ``super().__init__`` call — confirm against DatasetChatEnv before
        # refactoring.
        self._from_dataloader(
            self,
            dataloader=dataloader,
            repeats=repeats,
            device=device,
            group_repeats=group_repeats,
            batch_size=batch_size,
            tokenizer=tokenizer,
            template_kwargs=template_kwargs,
            input_mode=input_mode,
        )

        # Optional transforms: episode-length limiting and reward computation.
        # ``max_steps=0`` disables the step counter entirely.
        if max_steps:
            self.append_transform(StepCounter(max_steps=max_steps))
        if compute_reward:
            self.append_transform(CountdownRewardParser(tokenizer=tokenizer))
9 changes: 8 additions & 1 deletion torchrl/envs/llm/reward/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from .countdown import CountdownRewardParser
from .gsm8k import GSM8KRewardParser
from .ifeval import IFEvalScoreData, IfEvalScorer
from .math import MATHRewardParser

__all__ = ["IfEvalScorer", "GSM8KRewardParser", "IFEvalScoreData", "MATHRewardParser"]
# Public API of the reward subpackage; keep in sync with the imports above.
__all__ = [
    "CountdownRewardParser",
    "IfEvalScorer",
    "GSM8KRewardParser",
    "IFEvalScoreData",
    "MATHRewardParser",
]
Loading
Loading