
Commit 8816926

docs: GRPO documentation and Configuration cleanup (#7)
Signed-off-by: Sahil Jain <[email protected]>
1 parent 759ac40 commit 8816926

File tree

13 files changed: +170, -52 lines


README.md

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ uv pip install -e '.[dev,test]'
 # Use uv run to launch any runs.
 # Note that it is recommended to not activate the venv and instead use `uv run` since
 # it ensures consistent environment usage across different shells and sessions.
-uv run python examples/run_grpo.py
+uv run python examples/run_grpo_math.py
 ```
 
 ## Cluster Start

docs/assets/actor-wg-worker-vc.png

608 KB

docs/design_docs/design_and_philosophy.md

Lines changed: 25 additions & 8 deletions
@@ -9,8 +9,6 @@ Online RL requires coordinating a lot of different pieces of software/models
 
 We refer to each of these pieces of software as an **RL Actor**.
 
-[TODO @sahilj Diagram]
-
 Fundamentally, we need to be able to do 4 things between these RL Actors:
 - Resource them (provide GPUs/CPUs)
 - Isolate them
@@ -32,6 +30,8 @@ We create composable and hackable abstractions for each layer of the tasks above
 
 By creating a common interface for these 4 tasks, **RL algorithm code looks the same from 1 GPU to 1000 GPUs and does not care about the implementation of each RL Actor (Megatron, HF, Grad student with pen and paper)**
 
+![actor-wg-worker-vc](../assets/actor-wg-worker-vc.png)
+
 ### {py:class}`RayVirtualCluster <nemo_reinforcer.distributed.virtual_cluster.RayVirtualCluster>`
 VirtualCluster provides a basic abstraction on top of Ray Placement Groups that allows you to section off a part of your compute resources for WorkerGroups to run on as though they had their own cluster. They support running just one WorkerGroup on each VirtualCluster, or *colocation*, where multiple WorkerGroups share resources (i.e. running policy training (HF) and generation (vLLM) on the same GPUs in turn).
 
@@ -84,12 +84,29 @@ class RayWorkerGroup:
     - Support for tied worker groups where multiple workers process the same data
     """
 ```
-[TODO @sahilj Diagram]
-
+`RayWorkerGroup` provides functions like `run_all_workers_single_data` and `run_all_workers_multiple_data` to control and communicate with individual worker processes.
 
 
 ### Single-Controller & Execution Diagram
-
-## Walking through an implementation of GRPO
-
-
+We control the RL Actors using a single-process head controller. Using the aforementioned abstractions, this allows us to represent the main loop of GRPO as though we were working on 1 GPU:
+```python
+# data processing/transformations between each step omitted
+def grpo_train(
+    policy: PolicyInterface,
+    policy_generation: GenerationInterface,
+    environment: EnvironmentInterface,
+    dataloader: Iterable[BatchedDataDict[DatumSpec]],
+):
+    loss_fn = GRPOLossFn()
+    for batch in dataloader:
+        batch.repeat_interleave(num_generations_per_prompt)  # repeat each prompt for GRPO
+        generations = policy_generation.generate(batch)
+        rewards = environment.step(generations)
+
+        logprobs = policy.get_logprobs(generations)
+        reference_logprobs = policy.get_reference_logprobs(generations)
+
+        training_data = calculate_grpo_training_data(generations, logprobs, reference_logprobs, rewards)
+        policy.train(training_data, loss_fn)
+```
+For a real implementation of GRPO (with validation, checkpointing, memory movement, and the omitted data processing steps), see [grpo_train](../../nemo_reinforcer/algorithms/grpo.py)
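The `calculate_grpo_training_data` step above is where the group-relative credit assignment happens. As a rough sketch of that idea (generic GRPO-style math, not the repo's implementation; the function name and tensor shapes are assumptions, while the two flags mirror `normalize_rewards` and `use_leave_one_out_baseline` from the example config), the advantages for a group of generations per prompt can be computed as:

```python
import torch


def rewards_to_advantages(
    rewards: torch.Tensor,  # shape: (num_prompts, num_generations_per_prompt)
    use_leave_one_out_baseline: bool = True,
    normalize_rewards: bool = True,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Generic sketch of a group-relative (GRPO-style) advantage calculation."""
    num_generations = rewards.shape[1]
    if use_leave_one_out_baseline:
        # Baseline each sample against the mean reward of the *other* samples in its group.
        group_sum = rewards.sum(dim=1, keepdim=True)
        baseline = (group_sum - rewards) / max(num_generations - 1, 1)
    else:
        baseline = rewards.mean(dim=1, keepdim=True)
    advantages = rewards - baseline
    if normalize_rewards:
        # Scale by the per-group reward standard deviation to keep advantages comparable across prompts.
        advantages = advantages / (rewards.std(dim=1, keepdim=True) + eps)
    return advantages
```

Because each generation is scored relative to the other generations for the same prompt, no learned value function is needed.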

docs/guides/grpo.md

Lines changed: 92 additions & 2 deletions
@@ -1,3 +1,93 @@
-# GRPO
+# An in-depth walkthrough of GRPO in Reinforcer
 
-placeholder TBD
+## Quickstart: Launch a GRPO Run
+
+If you want to get running quickly, the script [examples/run_grpo_math.py](../../examples/run_grpo_math.py) has an example implementation of using GRPO to train a model on math problems. This script can be launched either locally or via Slurm. For details on how to set up Ray and launch a job using Slurm, refer to the [cluster documentation](../cluster.md).
+
+We recommend launching the job using `uv`:
+```bash
+uv run examples/run_grpo_math.py --config <PATH TO YAML CONFIG> {overrides}
+```
+If not specified, `config` will default to [examples/configs/grpo_math_1B.yaml](../../examples/configs/grpo_math_1B.yaml)
+
+## Now, for the details:
+
+In this guide, we'll walk through how we handle
+* Data
+* Model training
+* Fast generation
+* Overall resource flow
+
+### Data
+We support training with multiple RL "Environments" at the same time.
+
+An [Environment](../../nemo_reinforcer/environments/interfaces.py) is an object that accepts a state/action history and returns an updated state and rewards for the step. Environments run as Ray remote actors; see [MathEnvironment](../../nemo_reinforcer/environments/math_environment.py) for an example.
+
+To support this, we need to know:
+* What environments you have
+* Which data should go to which environments
+* How to prepare the data from your dataset into a form we can use
+
+#### Common Data Format
+We define a [DatumSpec](../../nemo_reinforcer/data/interfaces.py) that holds all relevant information for each training example:
+```python
+class DatumSpec(TypedDict):
+    message_log: LLMMessageLogType
+    length: int  # total (concatenated) length of the message tensors
+    extra_env_info: Dict[str, Any]  # anything your environment requires goes here, for example the 'answer' of a math problem
+    loss_multiplier: float  # multiplier for the loss for this datum. 0 to mask out (say the sample is invalid)
+    idx: int
+    task_name: Optional[str] = "default"
+    __extra__: Any  # This allows additional fields of any type
+```
+
+#### Data Processors
+We call each distinct "environment your model wants to optimize against" a "task", so you might define a "math" task or a "code" task.
+For each task, you should provide a data processor that reads from your dataset and returns a [DatumSpec](../../nemo_reinforcer/data/interfaces.py):
+
+```python
+def my_data_processor(
+    datum_dict: Dict[str, Any],  # loaded directly from your dataset (i.e. a single line of jsonl data)
+    task_data_spec: TaskDataSpec,
+    tokenizer,
+    max_seq_length: int,
+    idx: int,
+) -> DatumSpec:
+```
+We have an example of this as `math_data_processor` in [run_grpo_math.py](../../examples/run_grpo_math.py).
+
+#### Putting it all together:
+GRPO expects datasets to have the following form:
+```json
+{"task_name": "math", <actual data>}
+```
+Then, you can set up the data as follows:
+```python
+base_dataset = load_dataset("json", data_files=data_config["dataset_name"])["train"]
+tokenizer = AutoTokenizer.from_pretrained(policy_config["model_name"])
+
+task_data_processors = defaultdict(lambda: (math_task_spec, math_data_processor))
+task_data_processors["math"] = (math_task_spec, math_data_processor)
+
+math_env = MathEnvironment.remote(env_configs["math"])  # ray remote actor
+
+dataset = AllTaskProcessedDataset(
+    base_dataset,
+    tokenizer,
+    math_task_spec,
+    task_data_processors,
+    max_seq_length=data_config["max_input_seq_length"],
+)
+```
+Notice that you provide a mapping of tasks to their processors so the dataset knows what to use when processing samples.
+
+
+### Policy Model
+We define a [PolicyInterface]() that contains everything you need to train a Policy model.
+
+This Policy object holds a [RayWorkerGroup](../../nemo_reinforcer/distributed/worker_groups.py) of SPMD (1 process per GPU) workers that run HF/MCore, all coordinated by this object so it appears to you like 1 GPU!
+
+### Fast Generation
+We currently support vLLM through the [VllmGeneration](../../nemo_reinforcer/models/generation/vllm.py) class.
+
+The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the core GRPO training loop.
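To make the Environment contract described in the guide concrete, below is a minimal, hypothetical single-turn environment written as a Ray remote actor. The real interface lives in [interfaces.py](../../nemo_reinforcer/environments/interfaces.py); the `step` signature, the message-log layout, and `KeywordMatchEnvironment` itself are illustrative assumptions rather than the library's API.

```python
import ray


@ray.remote
class KeywordMatchEnvironment:
    """Toy single-turn environment: reward 1.0 if the final assistant message
    contains the keyword stored in the datum's extra_env_info, else 0.0."""

    def __init__(self, config: dict):
        self.config = config

    def step(self, message_logs: list, extra_env_info: list) -> list:
        rewards = []
        for messages, info in zip(message_logs, extra_env_info):
            # Assumed message layout: a list of {"role": ..., "content": ...} dicts.
            response = messages[-1]["content"]
            rewards.append(1.0 if info["keyword"] in response else 0.0)
        return rewards


# Launched and queried the same way the guide launches MathEnvironment:
# env = KeywordMatchEnvironment.remote({"num_workers": 1})
# rewards = ray.get(env.step.remote(message_logs, extra_env_info))
```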

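Similarly, a task data processor following the `my_data_processor` signature from the guide could look like the sketch below. The dataset fields (`problem`, `keyword`) and the message-log layout are assumptions for illustration; `math_data_processor` in [run_grpo_math.py](../../examples/run_grpo_math.py) is the real reference.

```python
from typing import Any, Dict


def keyword_data_processor(
    datum_dict: Dict[str, Any],  # one jsonl line, e.g. {"task_name": "keyword", "problem": ..., "keyword": ...}
    task_data_spec,              # TaskDataSpec for this task (prompt/system-prompt templates, etc.)
    tokenizer,
    max_seq_length: int,
    idx: int,
) -> Dict[str, Any]:             # stands in for DatumSpec
    prompt = datum_dict["problem"]
    token_ids = tokenizer(prompt, return_tensors="pt")["input_ids"][0]

    # Assumed message-log layout: one dict per turn with its tokenized content attached.
    message_log = [{"role": "user", "content": prompt, "token_ids": token_ids}]

    length = token_ids.shape[0]
    # Mask out over-long examples instead of crashing the run.
    loss_multiplier = 1.0 if length <= max_seq_length else 0.0

    return {
        "message_log": message_log,
        "length": length,
        "extra_env_info": {"keyword": datum_dict["keyword"]},
        "loss_multiplier": loss_multiplier,
        "idx": idx,
        "task_name": "keyword",
    }
```

Registering it is then a one-line addition to the `task_data_processors` mapping, mirroring the `"math"` entry in the guide's snippet.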
docs/guides/sft.md

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ The script [examples/run_sft.py](../../examples/run_sft.py) can be used to launc
 
 Be sure to launch the job using `uv`. The command to launch an SFT job is as follows:
 ```bash
-uv run examples/run_sft.py --config <PATH TO YAML CONFIG>
+uv run examples/run_sft.py --config <PATH TO YAML CONFIG> <OVERRIDES>
 ```
 If not specified, `config` will default to [examples/configs/sft.yaml](../../examples/configs/sft.yaml).
 
Lines changed: 32 additions & 8 deletions
@@ -1,10 +1,8 @@
 # GRPO Algorithm Configuration
-defaults: "base.yaml"
-
 grpo:
   num_prompts_per_step: 8
   num_generations_per_prompt: 8
-  num_steps: 100
+  max_num_steps: 100
   normalize_rewards: true
   use_leave_one_out_baseline: true
   val_period: 10
@@ -29,16 +27,28 @@ policy:
   train_global_batch_size: 32
   train_micro_batch_size: 4
   generation_batch_size: 32
+  learning_rate: 5.0e-6
   logprob_batch_size: 4
-  max_total_sequence_length: 1024
+  max_total_sequence_length: 512
+
+  scheduler:
+    - name: "torch.optim.lr_scheduler.LinearLR"
+      kwargs:
+        start_factor: 0.1
+        end_factor: 1.0
+        total_iters: 50
+    - name: "torch.optim.lr_scheduler.ConstantLR"
+      kwargs:
+        factor: 1.0
+        total_iters: 10000000000
+    - milestones: [50]
 
   generation:
-    backend: "vllm" # "vllm" or "hf" (to use the hf training framework's generation)
-    max_new_tokens: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len below
+    backend: "vllm"
+    max_new_tokens: ${policy.max_total_sequence_length}
     temperature: 1.0
-    # Don't change since vllm logprobs in V0 runtime are after sampling and in V1 runtime are before sampling.
     top_p: 1.0
-    top_k: null # disable
+    top_k: null
     vllm_cfg:
       tensor_parallel_size: 1
       gpu_memory_utilization: 0.7
@@ -54,3 +64,17 @@ data:
 env:
   math:
     num_workers: 8
+
+logger:
+  log_dir: "logs"  # Base directory for all logs
+  num_val_samples_to_print: 0  # Number of validation samples to pretty print on terminal
+  wandb_enabled: false
+  tensorboard_enabled: false
+  wandb:
+    project: "grpo-dev"
+    name: "grpo-dev-logger"
+  tensorboard: {}
+
+cluster:
+  gpus_per_node: 1
+  num_nodes: 1
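The `scheduler` block above lists two schedulers plus a `milestones` entry; in plain PyTorch, that kind of configuration typically maps onto `torch.optim.lr_scheduler.SequentialLR`, roughly as sketched below (an illustration of the schedule shape, not the repo's config-parsing code):

```python
import torch
from torch.optim.lr_scheduler import ConstantLR, LinearLR, SequentialLR

model = torch.nn.Linear(8, 8)  # stand-in for the policy model
optimizer = torch.optim.AdamW(model.parameters(), lr=5.0e-6)

# Linear warmup from 10% of the base LR over the first 50 steps, then hold constant.
warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=50)
constant = ConstantLR(optimizer, factor=1.0, total_iters=10_000_000_000)
scheduler = SequentialLR(optimizer, schedulers=[warmup, constant], milestones=[50])

for _ in range(100):
    optimizer.step()
    scheduler.step()
```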
@@ -1,12 +1,14 @@
-# Base configuration with common settings
+# GRPO Algorithm Configuration
+defaults: "grpo_math_1B.yaml"
+
 policy:
-  model_name: "meta-llama/Llama-3.2-1B-Instruct"
+  model_name: "meta-llama/Llama-3.1-8B-Instruct"
   train_global_batch_size: 32
-  train_micro_batch_size: 4
+  train_micro_batch_size: 1
   generation_batch_size: 32
   learning_rate: 5.0e-6
-  logprob_batch_size: 4
-  max_total_sequence_length: 8192
+  logprob_batch_size: 2
+  max_total_sequence_length: 4096
 
   scheduler:
     - name: "torch.optim.lr_scheduler.LinearLR"
@@ -28,24 +30,9 @@ policy:
     top_k: null
     vllm_cfg:
       tensor_parallel_size: 1
-      gpu_memory_utilization: 0.7
+      gpu_memory_utilization: 0.6
       max_model_len: ${policy.max_total_sequence_length}
-
-data:
-  max_input_seq_length: ${policy.max_total_sequence_length}
-  prompt_file: "examples/prompts/cot.txt"
-  system_prompt_file: null
-
-logger:
-  log_dir: "logs"  # Base directory for all logs
-  num_val_samples_to_print: 0  # Number of validation samples to pretty print on terminal
-  wandb_enabled: false
-  tensorboard_enabled: false
-  wandb:
-    project: "grpo-dev"
-    name: "grpo-dev-logger"
-  tensorboard: {}
-
+
 cluster:
-  gpus_per_node: 2
+  gpus_per_node: 8
   num_nodes: 1

examples/configs/sft.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 # SFT Algorithm Configuration
 sft:
-  num_steps: 20
+  max_num_steps: 20
   val_period: 10
   val_batches: 8
   val_global_batch_size: 32

examples/run_grpo_math.py

Lines changed: 1 addition & 1 deletion
@@ -170,7 +170,7 @@ def main():
     args, overrides = parse_args()
 
     if not args.config:
-        args.config = os.path.join(os.path.dirname(__file__), "configs", "grpo.yaml")
+        args.config = os.path.join(os.path.dirname(__file__), "configs", "grpo_math_1B.yaml")
 
     config = load_config(args.config)
     print(f"Loaded configuration from: {args.config}")

nemo_reinforcer/algorithms/grpo.py

Lines changed: 3 additions & 3 deletions
@@ -69,7 +69,7 @@
 class GRPOConfig(TypedDict):
     num_prompts_per_step: int
     num_generations_per_prompt: int
-    num_steps: int
+    max_num_steps: int
     normalize_rewards: bool
     use_leave_one_out_baseline: bool
     val_period: int
@@ -445,7 +445,7 @@ def grpo_train(
 
     # Run grpo training (single-turn)
     for batch in dataloader:
-        print(f"\n{'=' * 25} Step {step + 1}/{len(dataloader)} {'=' * 25}")
+        print(f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}")
 
         with timer.time("total_step_time"):
             # Prepare batch
@@ -654,7 +654,7 @@ def grpo_train(
 
         timer.reset()
         step += 1
-        if step >= master_config["grpo"]["num_steps"]:
+        if step >= master_config["grpo"]["max_num_steps"]:
             break
 
 