-
Notifications
You must be signed in to change notification settings - Fork 316
Expand file tree
/
Copy pathrun_grpo_rm.py
More file actions
216 lines (181 loc) · 6.8 KB
/
run_grpo_rm.py
File metadata and controls
216 lines (181 loc) · 6.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import pprint
from collections import defaultdict
from typing import Any, Optional
from omegaconf import OmegaConf
from transformers import PreTrainedTokenizerBase
from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data import DataConfig
from nemo_rl.data.datasets import AllTaskProcessedDataset, load_response_dataset
from nemo_rl.data.interfaces import (
TaskDataProcessFnCallable,
TaskDataSpec,
)
from nemo_rl.data.processors import math_hf_data_processor
from nemo_rl.distributed.ray_actor_environment_registry import get_actor_python_env
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.environments.interfaces import EnvironmentInterface
from nemo_rl.environments.reward_model_environment import RewardModelEnvironment
from nemo_rl.models.generation import configure_generation_config
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir
# Register a ``mul`` resolver so YAML configs can multiply values via
# OmegaConf interpolation (``${mul:x,y}``). OmegaConf raises if a resolver with
# the same name is already registered (e.g. by another module imported into the
# same process), so only register it when it is not present yet.
if not OmegaConf.has_resolver("mul"):
    OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
def parse_args(
    argv: Optional[list[str]] = None,
) -> tuple[argparse.Namespace, list[str]]:
    """Parse command line arguments.

    Args:
        argv: Arguments to parse. When ``None`` (the default), argparse reads
            ``sys.argv[1:]``, which is the behavior when run as a script.

    Returns:
        Tuple of (parsed_args, overrides) where:
        - parsed_args: Namespace object containing parsed arguments
        - overrides: List of remaining unparsed arguments (Hydra overrides)
    """
    parser = argparse.ArgumentParser(description="Run GRPO training with configuration")
    parser.add_argument(
        "--config", type=str, default=None, help="Path to YAML config file"
    )

    # parse_known_args returns unrecognized tokens instead of erroring, so they
    # can be forwarded as Hydra-style overrides.
    args, overrides = parser.parse_known_args(argv)
    return args, overrides
# ===============================================================================
# Data and environment setup (math-style prompts scored by a reward model)
# ===============================================================================
# Alias for the tokenizer type accepted by ``setup_data``.
TokenizerType = PreTrainedTokenizerBase
def setup_data(
    tokenizer: TokenizerType,
    data_config: DataConfig,
    env_configs: dict[str, Any],
    seed: int,
) -> tuple[
    AllTaskProcessedDataset,
    Optional[AllTaskProcessedDataset],
    dict[str, EnvironmentInterface],
    dict[str, EnvironmentInterface],
]:
    """Build the train/validation datasets and the reward-model environment.

    All samples are processed with ``math_hf_data_processor`` under a single
    ``"math"`` task spec, and every task is routed to one
    ``RewardModelEnvironment`` Ray actor.

    Args:
        tokenizer: Tokenizer passed to the processed datasets.
        data_config: Data section of the config. This function reads
            ``prompt_file``, ``system_prompt_file`` and ``max_input_seq_length``
            and passes the whole mapping to ``load_response_dataset``.
        env_configs: Environment section of the config;
            ``env_configs["reward_model"]`` is passed to the actor.
        seed: Seed forwarded to ``load_response_dataset``.

    Returns:
        ``(dataset, val_dataset, task_to_env, val_task_to_env)``. ``val_dataset``
        is ``None`` when the loaded data has no ``"validation"`` split or it is
        falsy. Both env mappings are the same ``defaultdict`` object, so any
        task name resolves to the shared reward-model actor.
    """
    print("\n▶ Setting up data...")
    task_name = "math"
    reward_model_task_spec = TaskDataSpec(
        task_name=task_name,
        prompt_file=data_config["prompt_file"],
        system_prompt_file=data_config["system_prompt_file"],
    )

    # load dataset
    data: Any = load_response_dataset(data_config, seed)

    # data processor: unknown task names fall back to the same spec/processor.
    task_data_processors: dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] = (
        defaultdict(lambda: (reward_model_task_spec, math_hf_data_processor))
    )
    task_data_processors[task_name] = (reward_model_task_spec, math_hf_data_processor)

    # Launch the reward-model actor; `.remote()` returns a handle immediately.
    reward_model_env = RewardModelEnvironment.options(  # type: ignore # it's wrapped with ray.remote
        runtime_env={
            "py_executable": get_actor_python_env(
                "nemo_rl.environments.reward_model_environment.RewardModelEnvironment"
            ),
            "env_vars": dict(os.environ),  # Pass thru all user environment variables
        }
    ).remote(env_configs["reward_model"])

    dataset = AllTaskProcessedDataset(
        data.formatted_ds["train"],
        tokenizer,
        reward_model_task_spec,
        task_data_processors,
        max_seq_length=data_config["max_input_seq_length"],
    )

    # `.get` so datasets that expose no "validation" key at all (not just a
    # None/empty one) simply skip validation instead of raising KeyError.
    val_split = data.formatted_ds.get("validation")
    val_dataset: Optional[AllTaskProcessedDataset] = None
    if val_split:
        val_dataset = AllTaskProcessedDataset(
            val_split,
            tokenizer,
            reward_model_task_spec,
            task_data_processors,
            max_seq_length=data_config["max_input_seq_length"],
        )

    task_to_env: dict[str, EnvironmentInterface] = defaultdict(lambda: reward_model_env)
    task_to_env[task_name] = reward_model_env
    # Train and validation share the same environment mapping.
    return dataset, val_dataset, task_to_env, task_to_env
def main() -> None:
    """Main entry point: load the config, build data/environments, run GRPO.

    Reads ``--config`` (defaulting to ``configs/grpo_rm_1B.yaml`` next to this
    script), applies any Hydra-style CLI overrides, resolves the config into
    plain Python containers, initializes Ray, sets up tokenizer, data and the
    reward-model environment, and calls ``grpo_train``. Environment actors are
    asked to shut down even if setup or training raises.

    Raises:
        ValueError: If ``config["policy"]["generation"]`` is ``None``.
    """
    # Parse arguments
    args, overrides = parse_args()

    if not args.config:
        args.config = os.path.join(
            os.path.dirname(__file__), "configs", "grpo_rm_1B.yaml"
        )

    config = load_config(args.config)
    print(f"Loaded configuration from: {args.config}")

    if overrides:
        print(f"Overrides: {overrides}")
        config = parse_hydra_overrides(config, overrides)
        print("Applied CLI overrides")

    # Always convert to plain containers with interpolations resolved, whether
    # or not overrides were given, so downstream code receives a mutable dict.
    config: MasterConfig = OmegaConf.to_container(config, resolve=True)

    # Print config
    print("Final config:")
    pprint.pprint(config)

    # Get the next experiment directory with incremented ID
    config["logger"]["log_dir"] = get_next_experiment_dir(config["logger"]["log_dir"])
    print(f"📊 Using log directory: {config['logger']['log_dir']}")
    if config["checkpointing"]["enabled"]:
        print(
            f"📊 Using checkpoint directory: {config['checkpointing']['checkpoint_dir']}"
        )

    init_ray()

    # setup tokenizer
    tokenizer = get_tokenizer(config["policy"]["tokenizer"])
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if config["policy"]["generation"] is None:
        raise ValueError("A generation config is required for GRPO")
    config["policy"]["generation"] = configure_generation_config(
        config["policy"]["generation"], tokenizer
    )

    # setup data
    (
        dataset,
        val_dataset,
        task_to_env,
        val_task_to_env,
    ) = setup_data(tokenizer, config["data"], config["env"], config["grpo"]["seed"])

    try:
        (
            policy,
            policy_generation,
            cluster,
            dataloader,
            val_dataloader,
            loss_fn,
            logger,
            checkpointer,
            grpo_state,
            master_config,
        ) = setup(config, tokenizer, dataset, val_dataset)

        grpo_train(
            policy,
            policy_generation,
            dataloader,
            val_dataloader,
            tokenizer,
            loss_fn,
            task_to_env,
            val_task_to_env,
            logger,
            checkpointer,
            grpo_state,
            master_config,
        )
    finally:
        # Request shutdown of every distinct environment actor exactly once,
        # even on failure. Train and validation maps may hold the same actor
        # (setup_data returns one dict twice), so deduplicate by identity.
        # Iterating .values() does not trigger the defaultdict factory.
        unique_envs = {
            id(env): env
            for env in (*task_to_env.values(), *val_task_to_env.values())
        }
        for env in unique_envs.values():
            env.shutdown.remote()
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()