Skip to content

Commit 4ef2fcf

Browse files
committed
squash (remove dependency update)
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent c40dba3 commit 4ef2fcf

File tree

14 files changed

+1095
-71
lines changed

14 files changed

+1095
-71
lines changed

examples/configs/gdpo_math_1B.yaml

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# GDPO: inherits from grpo_math_1B.yaml and overrides only what differs.
2+
defaults: grpo_math_1B.yaml
3+
4+
grpo:
5+
adv_estimator:
6+
name: "gdpo"
7+
normalize_rewards: true
8+
use_leave_one_out_baseline: false
9+
10+
checkpointing:
11+
checkpoint_dir: "results/gdpo"
12+
13+
policy:
14+
model_name: "Qwen/Qwen2.5-1.5B-Instruct"
15+
logprob_batch_size: 4
16+
max_total_sequence_length: 1024
17+
megatron_cfg:
18+
optimizer:
19+
weight_decay: 0.0
20+
scheduler:
21+
lr_decay_style: "cosine"
22+
lr_warmup_iters: 10
23+
24+
# GDPO uses a single flat data config (GSM8K + math_gdpo_data_processor); replace parent's train/validation/default.
25+
data:
26+
_override_: true
27+
max_input_seq_length: ${policy.max_total_sequence_length}
28+
prompt_file: "examples/prompts/cot.txt"
29+
system_prompt_file: "examples/prompts/gsm8k.txt"
30+
shuffle: true
31+
num_workers: 1
32+
processor: "math_gdpo_data_processor"
33+
env_name: "math"
34+
dataset_name: "gsm8k"
35+
36+
env:
37+
math:
38+
num_workers: 8
39+
math_verify_impl: "hf_math_verify"
40+
41+
logger:
42+
wandb_enabled: true
43+
wandb:
44+
project: "gdpo-dev"
45+
name: "gdpo-dev-logger"
46+
swanlab:
47+
project: "gdpo-dev"
48+
name: "gdpo-dev-logger"
49+
mlflow:
50+
experiment_name: "gdpo-dev"
51+
run_name: "gdpo-dev-logger"

examples/prompts/gsm8k.txt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
You are a helpful AI assistant.
2+
3+
For every request, you should carefully think through the math problem step by step, then provide the final answer in integer format.
4+
5+
Steps for Each Request:
6+
1. Think: Provide detailed, step-by-step reasoning, calculations, or derivations.
7+
2. Produce Final Answer: After step-by-step reasoning, output the final answer in integer format.
8+
9+
Output Format:
10+
<think>Your thoughts and reasoning</think>
11+
<answer>Final answer in integer format</answer>
12+
13+
Important Notes:
14+
1. You must include your reasoning steps inside <think>.
15+
2. You must always output the Final Answer within <answer> after the reasoning steps are done.
16+
3. You should consistently work through the solution step by step before giving the final answer.
17+
4. The final answer can only be an integer.

examples/run_gdpo_gsm8k.py

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import argparse
16+
import os
17+
import pprint
18+
from collections import defaultdict
19+
from typing import Any, Optional
20+
21+
from omegaconf import OmegaConf
22+
from transformers import PreTrainedTokenizerBase
23+
24+
from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup
25+
from nemo_rl.algorithms.utils import get_tokenizer
26+
from nemo_rl.data import DataConfig
27+
from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset
28+
from nemo_rl.data.interfaces import (
29+
TaskDataProcessFnCallable,
30+
TaskDataSpec,
31+
)
32+
from nemo_rl.data.processors import math_gdpo_data_processor
33+
from nemo_rl.distributed.ray_actor_environment_registry import (
34+
get_actor_python_env,
35+
)
36+
from nemo_rl.distributed.virtual_cluster import init_ray
37+
from nemo_rl.environments.interfaces import EnvironmentInterface
38+
from nemo_rl.environments.math_environment import MathMultiRewardEnvironment
39+
from nemo_rl.models.generation import configure_generation_config
40+
from nemo_rl.utils.config import load_config, parse_hydra_overrides
41+
from nemo_rl.utils.logger import get_next_experiment_dir
42+
43+
# Allow `${mul:a,b}` interpolation in YAML configs to compute a * b.
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
44+
45+
46+
def parse_args() -> tuple[argparse.Namespace, list[str]]:
    """Parse command line arguments.

    Returns:
        A tuple of (parsed args, leftover CLI tokens). Leftover tokens are
        later interpreted as Hydra-style config overrides by the caller.
    """
    # This script runs GDPO (see gdpo_math_1B.yaml); the help text previously
    # said "GRPO", a copy-paste leftover from the GRPO runner.
    parser = argparse.ArgumentParser(description="Run GDPO training with configuration")
    parser.add_argument(
        "--config", type=str, default=None, help="Path to YAML config file"
    )

    # Unknown args are collected as overrides instead of raising an error.
    args, overrides = parser.parse_known_args()

    return args, overrides
57+
58+
59+
# ===============================================================================
# Data setup
# ===============================================================================
# Alias for the HF tokenizer base class used in the signatures below.
TokenizerType = PreTrainedTokenizerBase
63+
64+
65+
def setup_data(
    tokenizer: TokenizerType,
    data_config: DataConfig,
    env_configs: dict[str, Any],
    seed: int,
) -> tuple[
    AllTaskProcessedDataset,
    Optional[AllTaskProcessedDataset],
    dict[str, EnvironmentInterface],
    dict[str, EnvironmentInterface],
]:
    """Build the train/validation datasets and the task-to-environment maps.

    NOTE(review): ``seed`` is accepted but never used here — presumably kept
    for signature parity with other runners; confirm before removing.
    """
    print("\n▶ Setting up data...")

    spec = TaskDataSpec(
        task_name="math",
        prompt_file=data_config["prompt_file"],
        system_prompt_file=data_config["system_prompt_file"],
    )

    # Load the response dataset and determine which task name it advertises.
    loaded: Any = load_response_dataset(data_config)
    if hasattr(loaded, "task_name"):
        task_name = loaded.task_name
    else:
        task_name = loaded.task_spec.task_name

    # Every task, known or not, routes to the same (spec, processor) pair.
    processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = defaultdict(
        lambda: (spec, math_gdpo_data_processor)
    )
    processors[task_name] = (spec, math_gdpo_data_processor)

    # Spin up the math reward environment as a Ray actor in its own env.
    math_env = MathMultiRewardEnvironment.options(  # type: ignore # it's wrapped with ray.remote
        runtime_env={
            "py_executable": get_actor_python_env(
                "nemo_rl.environments.math_environment.MathMultiRewardEnvironment"
            ),
            "env_vars": dict(os.environ),  # Pass thru all user environment variables
        }
    ).remote(env_configs["math"])

    max_len = data_config["max_input_seq_length"]
    train_ds = AllTaskProcessedDataset(
        loaded.dataset,
        tokenizer,
        spec,
        processors,
        max_seq_length=max_len,
    )

    val_ds: Optional[AllTaskProcessedDataset] = None
    if loaded.val_dataset is not None:
        val_ds = AllTaskProcessedDataset(
            loaded.val_dataset,
            tokenizer,
            spec,
            processors,
            max_seq_length=max_len,
        )

    # Train and validation share the single math environment for all tasks.
    env_map: dict[str, EnvironmentInterface] = defaultdict(lambda: math_env)
    env_map[task_name] = math_env
    return train_ds, val_ds, env_map, env_map
126+
127+
128+
def main() -> None:
129+
"""Main entry point."""
130+
# Parse arguments
131+
args, overrides = parse_args()
132+
133+
if not args.config:
134+
args.config = os.path.join(
135+
os.path.dirname(__file__), "configs", "gdpo_math_1B.yaml"
136+
)
137+
138+
config = load_config(args.config)
139+
print(f"Loaded configuration from: {args.config}")
140+
141+
if overrides:
142+
print(f"Overrides: {overrides}")
143+
config = parse_hydra_overrides(config, overrides)
144+
145+
config: MasterConfig = OmegaConf.to_container(config, resolve=True)
146+
print("Applied CLI overrides")
147+
148+
# Print config
149+
print("Final config:")
150+
pprint.pprint(config)
151+
152+
# Get the next experiment directory with incremented ID
153+
config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
154+
print(f"📊 Using log directory: {config['logger']['log_dir']}")
155+
if config["checkpointing"]["enabled"]:
156+
print(
157+
f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
158+
)
159+
160+
init_ray()
161+
162+
# setup tokenizer
163+
tokenizer = get_tokenizer(config["policy"]["tokenizer"])
164+
assert config["policy"]["generation"] is not None, (
165+
"A generation config is required for GRPO"
166+
)
167+
config["policy"]["generation"] = configure_generation_config(
168+
config["policy"]["generation"], tokenizer
169+
)
170+
171+
# setup data
172+
(
173+
dataset,
174+
val_dataset,
175+
task_to_env,
176+
val_task_to_env,
177+
) = setup_data(tokenizer, config["data"], config["env"], config["grpo"]["seed"])
178+
179+
(
180+
policy,
181+
policy_generation,
182+
cluster,
183+
dataloader,
184+
val_dataloader,
185+
loss_fn,
186+
logger,
187+
checkpointer,
188+
grpo_state,
189+
master_config,
190+
) = setup(config, tokenizer, dataset, val_dataset)
191+
192+
# Check if async mode is enabled
193+
if "async_grpo" in config["grpo"] and config["grpo"]["async_grpo"]["enabled"]:
194+
# Async GRPO does not support dynamic sampling, reward scaling, or reward shaping (DAPO features)
195+
unsupported_features = [
196+
"use_dynamic_sampling",
197+
"reward_scaling",
198+
"reward_shaping",
199+
]
200+
201+
for feature in unsupported_features:
202+
if feature not in config["grpo"]:
203+
continue
204+
205+
if feature == "use_dynamic_sampling":
206+
if config["grpo"][feature]:
207+
raise NotImplementedError(
208+
f"{feature} is not supported with async GRPO"
209+
)
210+
else:
211+
if config["grpo"][feature]["enabled"]:
212+
raise NotImplementedError(
213+
f"{feature} is not supported with async GRPO"
214+
)
215+
216+
from nemo_rl.algorithms.grpo import async_grpo_train
217+
218+
print("🚀 Running async GRPO training")
219+
220+
async_config = config["grpo"]["async_grpo"]
221+
# Run async GRPO training
222+
async_grpo_train(
223+
policy=policy,
224+
policy_generation=policy_generation,
225+
dataloader=dataloader,
226+
val_dataloader=val_dataloader,
227+
tokenizer=tokenizer,
228+
loss_fn=loss_fn,
229+
task_to_env=task_to_env,
230+
val_task_to_env=val_task_to_env,
231+
logger=logger,
232+
checkpointer=checkpointer,
233+
grpo_save_state=grpo_state,
234+
master_config=master_config,
235+
max_trajectory_age_steps=async_config["max_trajectory_age_steps"],
236+
)
237+
else:
238+
print("🚀 Running synchronous GRPO training")
239+
240+
# Run standard GRPO training
241+
grpo_train(
242+
policy,
243+
policy_generation,
244+
dataloader,
245+
val_dataloader,
246+
tokenizer,
247+
loss_fn,
248+
task_to_env,
249+
val_task_to_env,
250+
logger,
251+
checkpointer,
252+
grpo_state,
253+
master_config,
254+
)
255+
256+
257+
# Script entry point: run training only when executed directly.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)