
Commit df31c1b

feat: chunked logprob calculation with deferred fp32 cast to help with OOM (#918)
Signed-off-by: Peter Jin <[email protected]>
1 parent 83c6bfc commit df31c1b

File tree

16 files changed: +604 −77 lines changed


.github/actions/test-template/action.yml

Lines changed: 1 addition & 0 deletions

@@ -162,6 +162,7 @@ runs:
       --shm-size=64g \
       --env TRANSFORMERS_OFFLINE=0 \
       --env HYDRA_FULL_ERROR=1 \
+      --env HF_HUB_OFFLINE=1 \
       --env HF_HOME=/home/TestData/nemo-rl/hf_home \
       --env HF_DATASETS_CACHE=/home/TestData/nemo-rl/hf_datasets_cache \
       --env NEMO_RL_REPO_DIR=/opt/nemo-rl \

.gitmodules

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 [submodule "3rdparty/NeMo"]
 	path = 3rdparty/NeMo-workspace/NeMo
 	url = https://github.com/NVIDIA/NeMo.git
-	branch = https://github.com/NVIDIA/NeMo/tree/ashors/rl-qwen3-export
+	branch = pjin/ashors/rl-qwen3-export
 	shallow = true
 [submodule "3rdparty/Megatron-LM"]
 	path = 3rdparty/Megatron-LM-workspace/Megatron-LM

3rdparty/NeMo-workspace/NeMo

Submodule NeMo updated from aaefedd to 5c42641

examples/configs/grpo_math_1B.yaml

Lines changed: 60 additions & 0 deletions

@@ -41,6 +41,7 @@ policy:
   logprob_batch_size: 4
   max_total_sequence_length: 512
   precision: "bfloat16"
+  logprob_chunk_size: null
 
   dtensor_cfg:
     enabled: true
@@ -53,6 +54,65 @@ policy:
 
   megatron_cfg:
     enabled: false
+    empty_unused_memory_level: 0
+    activation_checkpointing: false
+    converter_type: "Qwen2ForCausalLM"
+    tensor_model_parallel_size: 1
+    expert_tensor_parallel_size: 1
+    expert_model_parallel_size: 1
+    pipeline_model_parallel_size: 1
+    num_layers_in_first_pipeline_stage: null
+    num_layers_in_last_pipeline_stage: null
+    context_parallel_size: 1
+    pipeline_dtype: ${policy.precision}
+    sequence_parallel: false
+    freeze_moe_router: true
+    moe_router_dtype: "fp64"
+    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
+    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
+    #gives ~20% training perf speedup with sequence packing
+    apply_rope_fusion: True
+    defer_fp32_logits: null
+
+    optimizer:
+      optimizer: "adam"
+      lr: 5.0e-6
+      min_lr: 5.0e-7
+      weight_decay: 0.01
+      bf16: true
+      fp16: false
+      params_dtype: "float32"
+
+      #adam
+      adam_beta1: 0.9
+      adam_beta2: 0.999
+      adam_eps: 1e-8
+
+      #sgd
+      sgd_momentum: 0.9
+
+      #distributed optimizer
+      use_distributed_optimizer: true
+      use_precision_aware_optimizer: true
+
+      clip_grad: ${policy.max_grad_norm}
+
+    scheduler:
+      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
+      weight_decay_incr_style: "constant"
+      lr_decay_style: "constant"
+      lr_decay_iters: null
+      lr_warmup_iters: 13
+      lr_warmup_init: 5.0e-7
+
+    distributed_data_parallel_config:
+      grad_reduce_in_fp32: false
+      overlap_grad_reduce: true
+      overlap_param_gather: true
+      average_in_collective: true
+      use_custom_fsdp: false
+      data_parallel_sharding_strategy: "optim_grads_params"
 
   # See docs/design-docs/sequence-packing-and-dynamic-batching.md
   # for more details on dynamic batching and sequence packing.
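
The two new keys above, `logprob_chunk_size` and `defer_fp32_logits` (both defaulted off in this base config), are the core of the commit: rather than upcasting the full `[batch, seq_len, vocab]` logits tensor to fp32 and log-softmaxing it in one shot, the logprob pass can walk the sequence in chunks and upcast one chunk at a time. Below is a minimal sketch of the idea; the function name, the zero logprob at position 0, and the exact chunking convention are illustrative assumptions, not the repo's implementation.

```python
import torch
import torch.nn.functional as F

def chunked_next_token_logprobs(
    logits: torch.Tensor,     # [batch, seq_len, vocab], e.g. bfloat16 straight from the model
    input_ids: torch.Tensor,  # [batch, seq_len], int64
    chunk_size: int = 2048,   # plays the role of policy.logprob_chunk_size
) -> torch.Tensor:
    """Compute log p(token_t | tokens_<t) without a full-sequence fp32 logits copy."""
    batch, seq_len, _ = logits.shape
    out = logits.new_zeros((batch, seq_len), dtype=torch.float32)  # position 0 stays 0
    for start in range(0, seq_len - 1, chunk_size):
        end = min(start + chunk_size, seq_len - 1)
        # Deferred fp32 cast: only this [batch, chunk, vocab] slice ever exists in float32.
        chunk_logprobs = F.log_softmax(logits[:, start:end].to(torch.float32), dim=-1)
        # Logits at position t predict the token at position t + 1.
        next_tokens = input_ids[:, start + 1 : end + 1].unsqueeze(-1)
        out[:, start + 1 : end + 1] = chunk_logprobs.gather(-1, next_tokens).squeeze(-1)
    return out
```

With `logprob_chunk_size: null` the chunked path is presumably bypassed and logprobs are computed in a single pass, so this base config keeps its previous behavior.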

examples/configs/grpo_math_qwen30ba3b_megatron.yaml

Lines changed: 0 additions & 3 deletions

@@ -56,9 +56,6 @@ policy:
       lr_warmup_iters: 13
       lr_warmup_init: 3.0e-8
 
-    env_vars:
-      PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"
-
   generation:
     backend: "vllm"
     max_new_tokens: ${policy.max_total_sequence_length}

Lines changed: 168 additions & 0 deletions (new file)

checkpointing:
  enabled: True
  checkpoint_dir: results/grpo-math-qwen3-30ba3b-megatron-tp4-32k
  save_period: 3
  keep_top_k: 1
  metric_name: val_reward
  higher_is_better: True
  checkpoint_must_save_by: null

grpo:
  normalize_rewards: True
  use_leave_one_out_baseline: True
  max_num_steps: 3
  num_prompts_per_step: 64
  num_generations_per_prompt: 16
  max_rollout_turns: 1
  val_period: 3
  val_at_start: False
  max_val_samples: 256
  val_batch_size: 256
  seed: 42

loss_fn:
  reference_policy_kl_penalty: 0.01
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  # (default off) loss formulation improvements (docs/guides/grpo.md#loss)
  use_on_policy_kl_approximation: False
  use_importance_sampling_correction: False
  token_level_loss: True
  ratio_clip_c: null

policy:
  model_name: "Qwen/Qwen3-30B-A3B"
  tokenizer:
    name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
  train_global_batch_size: 512
  train_micro_batch_size: 1
  generation_batch_size: 32 # Only used when generating using HF backend
  logprob_batch_size: 1
  max_total_sequence_length: 32768
  precision: "bfloat16"
  logprob_chunk_size: 2048

  dtensor_cfg:
    enabled: False

  dynamic_batching:
    enabled: False

  sequence_packing:
    enabled: False

  max_grad_norm: 1.0
  make_sequence_length_divisible_by: ${policy.megatron_cfg.tensor_model_parallel_size}

  optimizer: null # remove default FSDP optimizer

  scheduler: null # remove default FSDP scheduler

  megatron_cfg:
    enabled: True
    empty_unused_memory_level: 1
    converter_type: "LlamaForCausalLM"
    tensor_model_parallel_size: 4
    pipeline_model_parallel_size: 1
    context_parallel_size: 1
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 8
    sequence_parallel: True
    pipeline_dtype: ${policy.precision}
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    freeze_moe_router: True
    moe_router_dtype: "fp64"
    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
    apply_rope_fusion: True
    activation_checkpointing: True
    defer_fp32_logits: True

    optimizer:
      optimizer: "adam"
      lr: 5.0e-7
      min_lr: 5.0e-8
      weight_decay: 0.0
      bf16: True
      fp16: False
      params_dtype: "float32"

      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1e-8

      use_distributed_optimizer: True
      use_precision_aware_optimizer: True

      clip_grad: ${policy.max_grad_norm}

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: null
      lr_warmup_iters: 2
      lr_warmup_init: 5.0e-8

    distributed_data_parallel_config:
      grad_reduce_in_fp32: False
      overlap_grad_reduce: True
      overlap_param_gather: True
      average_in_collective: True
      use_custom_fsdp: False
      data_parallel_sharding_strategy: "optim_grads_params"

  generation:
    backend: "vllm"
    max_new_tokens: ${policy.max_total_sequence_length}
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: False
      precision: ${policy.precision}
      tensor_parallel_size: 4
      pipeline_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      # NB(pjin): https://github.com/NVIDIA-NeMo/RL/pull/857
      enforce_eager: True
    colocated:
      enabled: true
      resources:
        gpus_per_node: null
        num_nodes: null

data:
  dataset_name: "OpenMathInstruct-2"
  shuffle: true
  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound, real truncation occurs at vllm.max_model_len
  prompt_file: "examples/prompts/cot.txt"
  system_prompt_file: null

env:
  math:
    num_workers: 8

logger:
  log_dir: logs/grpo-math-qwen3-30ba3b-megatron-tp4-32k
  num_val_samples_to_print: 0 # Number of validation samples to pretty print on terminal
  wandb_enabled: True
  tensorboard_enabled: True
  mlflow_enabled: False # Disable MLflow logging
  monitor_gpus: False # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: nemo-rl
    name: "grpo-math-qwen3-30ba3b-megatron-tp4-32k"
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 8
  num_nodes: 4
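
This new 32k config is where the two features pay off together (`logprob_chunk_size: 2048`, `defer_fp32_logits: True`). A rough back-of-the-envelope, assuming a Qwen3 vocabulary of roughly 152k entries (an approximation, not a number taken from this diff):

```python
# Rough peak-memory estimate for one sequence's logits held in fp32.
seq_len = 32768      # policy.max_total_sequence_length
vocab = 151_936      # approximate Qwen3 vocab size (assumption)
chunk = 2048         # policy.logprob_chunk_size

full_fp32_gib = seq_len * vocab * 4 / 2**30   # ~18.5 GiB if the whole sequence is upcast at once
chunk_fp32_gib = chunk * vocab * 4 / 2**30    # ~1.2 GiB at a time with chunking
print(f"{full_fp32_gib:.1f} GiB vs {chunk_fp32_gib:.1f} GiB per sequence")
```

At that scale an eager fp32 cast of the full logits is an easy way to OOM even with `logprob_batch_size: 1`, which is what the commit title is addressing.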

nemo_rl/algorithms/loss_functions.py

Lines changed: 3 additions & 5 deletions

@@ -137,8 +137,6 @@ def __call__(
             global_normalization_factor=global_valid_toks,
         ).item()
 
-        next_token_logits = next_token_logits.to(torch.float32)
-
         if vocab_parallel_group is not None:
             assert vocab_parallel_rank is not None, (
                 "vocab_parallel_rank must be provided when vocab_parallel_group is provided"
@@ -159,6 +157,7 @@
                 next_token_logits, data["input_ids"], seq_index=seq_index
             )
         else:
+            next_token_logits = next_token_logits.to(torch.float32)
             next_token_logits_wo_last = next_token_logits[
                 :, :-1
             ]  # Remove last position's logits
@@ -327,8 +326,6 @@ def __call__(
         mask = token_mask * sample_mask.unsqueeze(-1)
         seq_index = data.get("seq_index", None)
 
-        next_token_logits = next_token_logits.to(torch.float32)
-
         # Gather the logprobs for the actual next tokens
         if vocab_parallel_group is not None:
             assert vocab_parallel_rank is not None, (
@@ -351,6 +348,7 @@
             )
         else:
             next_tokens = data["input_ids"][:, 1:].cuda()  # Skip first token
+            next_token_logits = next_token_logits.to(torch.float32)
             next_token_logprobs = torch.nn.functional.log_softmax(
                 next_token_logits, dim=-1
             )
@@ -583,7 +581,6 @@ def _dpo_loss(
        sample_mask = data["sample_mask"]
        seq_index = data.get("seq_index", None)
 
-        next_token_logits = next_token_logits.to(torch.float32)
        if vocab_parallel_group is not None:
            assert vocab_parallel_rank is not None, (
                "vocab_parallel_rank must be provided when vocab_parallel_group is provided"
@@ -605,6 +602,7 @@
            )
        else:
            next_tokens = data["input_ids"][:, 1:].cuda()  # Skip first token
+            next_token_logits = next_token_logits.to(torch.float32)
            next_token_logprobs = torch.nn.functional.log_softmax(
                next_token_logits, dim=-1
            )
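
The loss-function change is small but deliberate: the `next_token_logits.to(torch.float32)` cast moves from the top of each function into the local (non vocab-parallel) branch, so the vocab-parallel path never has to materialize a full-vocab fp32 copy before its own distributed logprob computation. A toy illustration of the two cast placements, with made-up function names and shapes rather than the repo's code:

```python
import torch
import torch.nn.functional as F

def token_logprobs_cast_early(next_token_logits, input_ids):
    # Old pattern: upcast the full [batch, seq, vocab] logits before branching,
    # doubling their footprint even when another branch could handle the cast
    # on much smaller per-rank or per-chunk slices.
    next_token_logits = next_token_logits.to(torch.float32)
    logprobs = F.log_softmax(next_token_logits[:, :-1], dim=-1)
    return logprobs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

def token_logprobs_cast_late(next_token_logits, input_ids):
    # Pattern after this commit (sketch): cast inside the branch that actually
    # needs full-vocab fp32, immediately before log_softmax.
    logits_fp32 = next_token_logits[:, :-1].to(torch.float32)
    logprobs = F.log_softmax(logits_fp32, dim=-1)
    return logprobs.gather(-1, input_ids[:, 1:].unsqueeze(-1)).squeeze(-1)

# The cast is elementwise, so both placements agree on toy data;
# only the peak memory at which the cast happens differs.
logits = torch.randn(2, 5, 11, dtype=torch.bfloat16)
ids = torch.randint(0, 11, (2, 5))
assert torch.allclose(token_logprobs_cast_early(logits, ids),
                      token_logprobs_cast_late(logits, ids))
```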
