Skip to content

Commit 828dd54

Browse files
committed
[LLM] Wire MATH and Countdown into GRPO and Expert Iteration scripts
Add GRPO config files for math and countdown datasets, and update grpo_utils.py and ei_utils.py to support the new dataset choices. All four datasets (gsm8k, ifeval, math, countdown) are now selectable via the env.dataset config key.

Made-with: Cursor
ghstack-source-id: 8713155
Pull-Request: #3546
1 parent 76f0798 commit 828dd54

File tree

5 files changed

+259
-32
lines changed

5 files changed

+259
-32
lines changed

sota-implementations/expert-iteration/ei_utils.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
from torchrl._utils import logger as torchrl_logger
1717
from torchrl.envs.llm import RetrieveLogProb
18+
from torchrl.envs.llm.datasets.countdown import CountdownEnv
1819
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
20+
from torchrl.envs.llm.datasets.math import MATHEnv
1921
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
2022
from torchrl.weight_update.llm import VLLMWeightSyncScheme
2123
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
@@ -63,22 +65,24 @@ def make_env(cfg: DictConfig, devices: list[int] | None = None):
6365
ref_model = get_ref_model(ref_cfg, train_tokenizer, devices=devices)
6466

6567
# Setup environment
68+
common_kwargs = {
69+
"repeats": cfg.env.repeats,
70+
"tokenizer": train_tokenizer,
71+
"num_envs": cfg.env.num_envs,
72+
"device": torch.device("cpu"),
73+
}
6674
if cfg.env.dataset == "gsm8k":
6775
from torchrl.envs.llm import GSM8KEnv
6876

69-
env = GSM8KEnv(
70-
repeats=cfg.env.repeats,
71-
tokenizer=train_tokenizer,
72-
num_envs=cfg.env.num_envs,
73-
device=torch.device("cpu"),
74-
)
75-
else: # ifeval
76-
env = IFEvalEnv(
77-
repeats=cfg.env.repeats,
78-
tokenizer=train_tokenizer,
79-
num_envs=cfg.env.num_envs,
80-
device=torch.device("cpu"),
81-
)
77+
env = GSM8KEnv(**common_kwargs)
78+
elif cfg.env.dataset == "ifeval":
79+
env = IFEvalEnv(**common_kwargs)
80+
elif cfg.env.dataset == "math":
81+
env = MATHEnv(**common_kwargs)
82+
elif cfg.env.dataset == "countdown":
83+
env = CountdownEnv(**common_kwargs)
84+
else:
85+
raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")
8286

8387
# Pass device directly to RetrieveLogProb - Since, for Ray, the local device is always 0
8488
# we can just use 0 here.
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# @package _global_
2+
defaults:
3+
- mode: ${mode:async}
4+
- _self_
5+
- override hydra/hydra_logging: disabled
6+
- override hydra/job_logging: disabled
7+
8+
env:
9+
dataset: countdown
10+
num_envs: 32
11+
repeats: 16
12+
reasoning: false
13+
max_steps: 2
14+
15+
model:
16+
name: Qwen/Qwen2.5-3B
17+
compile: false
18+
19+
train:
20+
exp_name: "grpo-countdown"
21+
mixed_precision: true
22+
total_dialog_turns: 100_000
23+
packing: false
24+
dialog_turns_per_batch: 32
25+
gradient_accumulation_steps: 8
26+
checkpoint_frequency: 100
27+
optim_batch_size: 32
28+
kl_coef_in_loss: true
29+
use_kl_to_ref: false
30+
kl_to_ref_coeff: 0.0
31+
kl_to_inference_coeff: 1e-2
32+
entropy_coeff: 1e-4
33+
logging_frequency: 10
34+
empty_replay_buffer: true
35+
36+
train_model:
37+
gradient_checkpointing: true
38+
num_devices: 1
39+
lora:
40+
enabled: true
41+
r: 8
42+
alpha: 16
43+
dropout: 0.1
44+
quantization:
45+
enabled: false
46+
attn_implementation: sdpa
47+
torch_dtype: bfloat16
48+
49+
inference_model:
50+
num_devices: 1
51+
quantization:
52+
enabled: false
53+
attn_implementation: sdpa
54+
torch_dtype: bfloat16
55+
gpu_memory_utilization: 0.9
56+
temperature: 1.0
57+
top_p: 0.95
58+
max_tokens: 512
59+
include_stop_str_in_output: true
60+
enforce_eager: false
61+
62+
ref_model:
63+
gradient_checkpointing: false
64+
num_devices: 1
65+
lora:
66+
enabled: true
67+
r: 8
68+
alpha: 16
69+
dropout: 0.1
70+
quantization:
71+
enabled: false
72+
attn_implementation: sdpa
73+
torch_dtype: bfloat16
74+
75+
optimizer:
76+
name: AdamW
77+
lr: 1e-5
78+
clip_grad_norm: 1.0
79+
weight_decay: 0.0
80+
81+
ray:
82+
init_config:
83+
num_cpus: 96
84+
num_gpus: 8
85+
runtime_env:
86+
working_dir: "."
87+
_temp_dir: "/tmp/ray_grpo"
88+
_system_config:
89+
object_spilling_threshold: 0.8
90+
max_direct_memory_size: 10 * 1024 * 1024 * 1024
91+
object_store_full_delay_ms: 100
92+
object_store_full_max_retries: 3
93+
collector_config:
94+
num_cpus: 4
95+
train_handler_config:
96+
num_cpus: 4
97+
replay_buffer_config:
98+
num_cpus: 4
99+
num_gpus: 0.0
100+
101+
logging:
102+
experiment_name: null
103+
checkpoint_dir: "checkpoints"
104+
checkpoint_frequency: 10
105+
106+
hydra:
107+
run:
108+
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
109+
sweep:
110+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
111+
subdir: ${hydra.job.num}

sota-implementations/grpo/config/grpo_gsm8k.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ defaults:
77

88
# Environment configuration
99
env:
10-
dataset: gsm8k # choices: [gsm8k, ifeval]
10+
dataset: gsm8k # choices: [gsm8k, ifeval, math, countdown]
1111
# Number of environments to run in parallel. This determines the batch size passed to vLLM.
1212
# More envs do not consume more GPU memory but there will be a sync on the call to vLLM.
1313
num_envs: 32
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# @package _global_
2+
defaults:
3+
- mode: ${mode:async}
4+
- _self_
5+
- override hydra/hydra_logging: disabled
6+
- override hydra/job_logging: disabled
7+
8+
env:
9+
dataset: math
10+
num_envs: 32
11+
repeats: 16
12+
reasoning: false
13+
max_steps: 2
14+
15+
model:
16+
name: Qwen/Qwen2.5-3B
17+
compile: false
18+
19+
train:
20+
exp_name: "grpo-math"
21+
mixed_precision: true
22+
total_dialog_turns: 100_000
23+
packing: false
24+
dialog_turns_per_batch: 32
25+
gradient_accumulation_steps: 8
26+
checkpoint_frequency: 100
27+
optim_batch_size: 32
28+
kl_coef_in_loss: true
29+
use_kl_to_ref: false
30+
kl_to_ref_coeff: 0.0
31+
kl_to_inference_coeff: 1e-2
32+
entropy_coeff: 1e-4
33+
logging_frequency: 10
34+
empty_replay_buffer: true
35+
36+
train_model:
37+
gradient_checkpointing: true
38+
num_devices: 1
39+
lora:
40+
enabled: true
41+
r: 8
42+
alpha: 16
43+
dropout: 0.1
44+
quantization:
45+
enabled: false
46+
attn_implementation: sdpa
47+
torch_dtype: bfloat16
48+
49+
inference_model:
50+
num_devices: 1
51+
quantization:
52+
enabled: false
53+
attn_implementation: sdpa
54+
torch_dtype: bfloat16
55+
gpu_memory_utilization: 0.9
56+
temperature: 1.0
57+
top_p: 0.95
58+
max_tokens: 1024
59+
include_stop_str_in_output: true
60+
enforce_eager: false
61+
62+
ref_model:
63+
gradient_checkpointing: false
64+
num_devices: 1
65+
lora:
66+
enabled: true
67+
r: 8
68+
alpha: 16
69+
dropout: 0.1
70+
quantization:
71+
enabled: false
72+
attn_implementation: sdpa
73+
torch_dtype: bfloat16
74+
75+
optimizer:
76+
name: AdamW
77+
lr: 1e-5
78+
clip_grad_norm: 1.0
79+
weight_decay: 0.0
80+
81+
ray:
82+
init_config:
83+
num_cpus: 96
84+
num_gpus: 8
85+
runtime_env:
86+
working_dir: "."
87+
_temp_dir: "/tmp/ray_grpo"
88+
_system_config:
89+
object_spilling_threshold: 0.8
90+
max_direct_memory_size: 10 * 1024 * 1024 * 1024
91+
object_store_full_delay_ms: 100
92+
object_store_full_max_retries: 3
93+
collector_config:
94+
num_cpus: 4
95+
train_handler_config:
96+
num_cpus: 4
97+
replay_buffer_config:
98+
num_cpus: 4
99+
num_gpus: 0.0
100+
101+
logging:
102+
experiment_name: null
103+
checkpoint_dir: "checkpoints"
104+
checkpoint_frequency: 10
105+
106+
hydra:
107+
run:
108+
dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S}
109+
sweep:
110+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
111+
subdir: ${hydra.job.num}

sota-implementations/grpo/grpo_utils.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
from torchrl._utils import logger as torchrl_logger, timeit
1818
from torchrl.envs.llm import AddThinkingPrompt, GSM8KEnv, KLRewardTransform, RetrieveKL
19+
from torchrl.envs.llm.datasets.countdown import CountdownEnv
1920
from torchrl.envs.llm.datasets.ifeval import IFEvalEnv
21+
from torchrl.envs.llm.datasets.math import MATHEnv
2022
from torchrl.modules.llm import TransformersWrapper, vLLMWrapper
2123
from torchrl.weight_update.llm import VLLMWeightSyncScheme
2224
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
@@ -648,28 +650,27 @@ def make_env(cfg: DictConfig, single_env: bool = False):
648650

649651
# Setup environment
650652
max_steps = cfg.env.max_steps if cfg.env.reasoning else 1
653+
num_envs = cfg.env.num_envs if not single_env else 1
654+
common_kwargs = {
655+
"repeats": cfg.env.repeats,
656+
"tokenizer": train_tokenizer,
657+
"num_envs": num_envs,
658+
"max_steps": max_steps,
659+
"device": torch.device("cpu"),
660+
}
661+
651662
if cfg.env.dataset == "gsm8k":
652-
# Reward scale is 0.0 to 1.0
653663
reward_threshold = 0.1
654-
env = GSM8KEnv(
655-
repeats=cfg.env.repeats,
656-
tokenizer=train_tokenizer,
657-
num_envs=cfg.env.num_envs if not single_env else 1,
658-
max_steps=max_steps,
659-
device=torch.device("cpu"),
660-
ray_backend=True,
661-
)
664+
env = GSM8KEnv(**common_kwargs, ray_backend=True)
662665
elif cfg.env.dataset == "ifeval":
663-
# Reward scale is 0.0 to ~1.15
664666
reward_threshold = 0.5
665-
env = IFEvalEnv(
666-
repeats=cfg.env.repeats,
667-
tokenizer=train_tokenizer,
668-
num_envs=cfg.env.num_envs if not single_env else 1,
669-
max_steps=max_steps,
670-
device=torch.device("cpu"),
671-
ray_backend=True,
672-
)
667+
env = IFEvalEnv(**common_kwargs, ray_backend=True)
668+
elif cfg.env.dataset == "math":
669+
reward_threshold = 0.1
670+
env = MATHEnv(**common_kwargs, ray_backend=True)
671+
elif cfg.env.dataset == "countdown":
672+
reward_threshold = 0.1
673+
env = CountdownEnv(**common_kwargs)
673674
else:
674675
raise NotImplementedError(f"Dataset {cfg.env.dataset} not implemented")
675676

0 commit comments

Comments
 (0)