
Commit 56a6225

feat: add support for nemotron-nas with custom plan. (#1180)
Signed-off-by: Jonas Yang <[email protected]>
1 parent 7aa7071 commit 56a6225

6 files changed (+303, -4 lines changed)

Lines changed: 180 additions & 0 deletions
# GRPO Algorithm Configuration
grpo:
  num_prompts_per_step: 128
  num_generations_per_prompt: 16
  max_rollout_turns: 1 # for multi-turn rollouts. Math environments have just 1 turn (answering the question)
  max_num_epochs: 1
  max_num_steps: 1000000
  normalize_rewards: true
  use_leave_one_out_baseline: true
  val_period: 10
  val_at_start: false
  overlong_filtering: false
  max_val_samples: 256
  val_batch_size: 256
  seed: 42
  async_grpo:
    enabled: false
    max_trajectory_age_steps: 1

loss_fn:
  reference_policy_kl_penalty: 0.01
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  ratio_clip_c: null
  # (default off) loss formulation improvements (docs/guides/grpo.md#loss)
  use_on_policy_kl_approximation: false
  use_importance_sampling_correction: false
  sequence_level_importance_ratios: false
  token_level_loss: true

checkpointing:
  enabled: true
  checkpoint_dir: "results/grpo"
  metric_name: "val_reward"
  higher_is_better: true
  keep_top_k: 3
  save_period: 10
  checkpoint_must_save_by: null
  model_save_format: "safetensors"
  save_consolidated: false

policy:
  model_name: "nvidia/Llama-3_3-Nemotron-Super-49B-v1_5"
  tokenizer:
    name: "nvidia/Llama-3_3-Nemotron-Super-49B-v1_5"
  max_total_sequence_length: 1024
  precision: "bfloat16"
  train_global_batch_size: 128
  train_micro_batch_size: 4
  logprob_batch_size: 4
  logprob_chunk_size: null

  dtensor_cfg:
    _v2: true
    activation_checkpointing: true
    context_parallel_size: 1
    cpu_offload: false
    enabled: true
    sequence_parallel: false
    tensor_parallel_size: 8
    custom_parallel_plan: examples.configs.recipes.llm.llama_nemotron_super_49b_custom_plan.custom_parallel_plan

  megatron_cfg:
    enabled: false

  # See docs/design-docs/sequence-packing-and-dynamic-batching.md
  # for more details on dynamic batching and sequence packing.
  dynamic_batching:
    enabled: True
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    sequence_length_round: 64

  sequence_packing:
    enabled: False
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
  max_grad_norm: 1.0

  optimizer:
    name: "torch.optim.AdamW"
    kwargs:
      lr: 3.0e-7
      weight_decay: 0.01
      betas: [0.9, 0.999]
      eps: 1e-8

  scheduler:
    - name: "torch.optim.lr_scheduler.LinearLR"
      kwargs:
        start_factor: 0.1
        end_factor: 1.0
        # The scheduler iteration is per GRPO step and is decoupled from the optimizer step (there may be >=1 optimizer steps per GRPO step)
        total_iters: 13
    - name: "torch.optim.lr_scheduler.ConstantLR"
      kwargs:
        factor: 1.0
        total_iters: 10000000000
    - milestones: [13]

  generation:
    backend: "vllm"
    max_new_tokens: ${policy.max_total_sequence_length}
    temperature: 1.0
    top_p: 1.0
    top_k: null
    stop_token_ids: null
    stop_strings: null
    vllm_cfg:
      async_engine: false
      precision: ${policy.precision}
      tensor_parallel_size: 4
      pipeline_parallel_size: 1
      expert_parallel_size: 1 # When EP > 1, EP must be a multiple of TP since vLLM's EP = DP * TP
      gpu_memory_utilization: 0.6
      max_model_len: ${policy.max_total_sequence_length}
      # When enforce_eager is False, you can optionally set ++policy.generation.vllm_kwargs.compilation_config.use_inductor=False for better accuracy;
      # with that flag, vLLM uses its custom CUDA kernels instead of the Triton kernels generated by torch.compile.
      # For more details, see the convergence issue https://github.com/NVIDIA-NeMo/RL/issues/998
      enforce_eager: False
      use_deep_gemm: False
      num_last_layers_in_bf16: 0
      num_first_layers_in_bf16: 0
      vllm_kwargs: {}
    colocated:
      # true: generation shares training GPUs
      # false: uses dedicated generation resources
      enabled: true
      # only relevant when enabled is false
      resources:
        gpus_per_node: null # Number of GPUs dedicated to generation when the cluster has a single node (i.e., cluster.num_nodes == 1)
        num_nodes: null # Number of nodes dedicated to generation

data:
  max_input_seq_length: ${policy.max_total_sequence_length} # upper bound; actual truncation occurs at vllm.max_model_len
  prompt_file: "examples/prompts/cot.txt"
  system_prompt_file: null
  shuffle: true

  dataset_name: "OpenMathInstruct-2"
  # You can use custom response datasets for training and validation. For example:
  # data:
  #   dataset_name: ResponseDataset
  #   train_data_path: <PathToTrainingDataset> # e.g., /path/to/local/dataset.jsonl or hf_org/hf_dataset_name (HuggingFace)
  #   val_data_path: <PathToValidationDataset>
  #   input_key: <QuestionKey>, default is "input"
  #   output_key: <AnswerKey>, default is "output"
  #   train_split: <TrainSplit>, default is None # used for HuggingFace datasets
  #   val_split: <ValSplit>, default is None # used for HuggingFace datasets
  # See https://github.com/NVIDIA-NeMo/RL/blob/main/docs/guides/grpo.md#datasets for more details.

env:
  math:
    num_workers: 8

logger:
  log_dir: "logs" # Base directory for all logs
  num_val_samples_to_print: 0
  wandb_enabled: true # Make sure you run `wandb login [Your API key]` before running
  tensorboard_enabled: false
  mlflow_enabled: false
  monitor_gpus: false # If true, will monitor GPU usage and log to wandb and/or tensorboard
  wandb:
    project: "grpo-nemotron-super-49b"
    name: "grpo-${data.dataset_name}-nemotron-super-49b-tp${policy.dtensor_cfg.tensor_parallel_size}"
  tensorboard: {}
  mlflow:
    experiment_name: "sft-dev"
    run_name: "grpo-nemotron-super-49b"
  gpu_monitoring:
    collection_interval: 10 # How often to collect GPU usage metrics (in seconds)
    flush_interval: 10 # How often to flush GPU usage metrics to the loggers (in seconds)

cluster:
  gpus_per_node: 8
  num_nodes: 4
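
The dynamic_batching and sequence_packing token budgets above are computed with ${mul:...} interpolations. A minimal sketch of how that arithmetic resolves, assuming mul is a custom OmegaConf resolver registered by the repo (the standalone registration below is for illustration only):

from omegaconf import OmegaConf

# Assumed resolver: multiplies its two arguments. NeMo RL is presumed to
# register something equivalent before loading this config.
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)

cfg = OmegaConf.create(
    """
    policy:
      max_total_sequence_length: 1024
      train_micro_batch_size: 4
      dynamic_batching:
        train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    """
)
# Accessing the key triggers resolution of the nested interpolations.
print(cfg.policy.dynamic_batching.train_mb_tokens)  # 4096 = 1024 * 4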
Lines changed: 49 additions & 0 deletions
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    ParallelStyle,
    PrepareModuleInput,
    PrepareModuleOutput,
    RowwiseParallel,
)
from torch.distributed.tensor.placement_types import Replicate, Shard

custom_parallel_plan: dict[str, ParallelStyle] = {
    "model.layers.*.self_attn": PrepareModuleInput(
        input_kwarg_layouts={"attention_mask": Replicate()},
        desired_input_kwarg_layouts={"attention_mask": Replicate()},
    ),
    "model.embed_tokens": RowwiseParallel(
        input_layouts=Replicate(), output_layouts=Replicate(), use_local_output=True
    ),
    "model.layers.*.self_attn.q_proj": ColwiseParallel(use_local_output=False),
    "model.layers.*.self_attn.k_proj": ColwiseParallel(use_local_output=False),
    "model.layers.*.self_attn.v_proj": ColwiseParallel(use_local_output=False),
    "model.layers.*.self_attn.o_proj": RowwiseParallel(
        output_layouts=Replicate(), use_local_output=True
    ),
    "model.layers.*.self_attn.rotary_emb": PrepareModuleOutput(
        output_layouts=(Replicate(), Replicate()),
        desired_output_layouts=(Replicate(), Replicate()),
        use_local_output=False,
    ),
    "model.layers.*.mlp.up_proj": ColwiseParallel(),
    "model.layers.*.mlp.gate_proj": ColwiseParallel(),
    "model.layers.*.mlp.down_proj": RowwiseParallel(
        output_layouts=Replicate(), use_local_output=True
    ),
    "lm_head": ColwiseParallel(output_layouts=Shard(-1), use_local_output=False),
}
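
This module is what policy.dtensor_cfg.custom_parallel_plan in the config above points to: a mapping from module FQN patterns to ParallelStyle objects. For context, a minimal sketch of how torch consumes such a plan via parallelize_module, using a toy model and mesh (names here are illustrative, not the repo's actual wiring; run under torchrun with 8 ranks):

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Hypothetical two-layer MLP standing in for a transformer block.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()

# 1-D tensor-parallel mesh over 8 GPUs, matching tensor_parallel_size: 8 above.
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))

# FQN -> ParallelStyle, the same shape as custom_parallel_plan.
toy_plan = {
    "0": ColwiseParallel(),  # shard the first Linear along its output dim
    "2": RowwiseParallel(),  # shard the second Linear along its input dim
}
model = parallelize_module(model, tp_mesh, toy_plan)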

nemo_rl/models/policy/dtensor_policy_worker_v2.py

Lines changed: 29 additions & 3 deletions
@@ -39,7 +39,6 @@
 )
 from nemo_automodel.components.distributed.parallelizer import (
     fsdp2_strategy_parallelize,
-    unshard_fsdp2_model,
 )
 from nemo_automodel.components.distributed.tensor_utils import (
     get_cpu_state_dict,
@@ -181,6 +180,10 @@ def __init__(
             else None,
         )

+        self.allow_flash_attn_args = self.check_model_allow_flash_attn_args(
+            model_config
+        )
+
         self._is_reward_model = (
             "reward_model_cfg" in self.cfg and self.cfg["reward_model_cfg"]["enabled"]
         )
@@ -467,6 +470,17 @@ def init_collective(self, ip: str, port: int, world_size: int) -> None:
     def is_alive(self) -> bool:
         return True

+    def check_model_allow_flash_attn_args(self, model_config) -> bool:
+        # Some models don't support flash_attn_kwargs.
+        # Check for nemotron-nas.
+        if (
+            model_config.architectures[0] == "DeciLMForCausalLM"
+            and model_config.model_type == "nemotron-nas"
+        ):
+            return False
+
+        return True
+
     def reset_peak_memory_stats(self) -> None:
         torch.cuda.reset_peak_memory_stats()

@@ -686,6 +700,12 @@ def train(
                 if len(vlm_kwargs) > 0:
                     del model_args["flash_attn_kwargs"]

+                if (
+                    not self.allow_flash_attn_args
+                    and "flash_attn_kwargs" in model_args
+                ):
+                    del model_args["flash_attn_kwargs"]
+
                 outputs = self.model(**model_args)

                 # Get logprobs
@@ -879,7 +899,7 @@ def get_logprobs(
         all_log_probs = []
         self.model.eval()

-        with unshard_fsdp2_model(self.model), torch.no_grad():
+        with torch.no_grad():
             data.to("cuda")
             dummy_iterator = iter([])
             if self.cfg["dynamic_batching"]["enabled"]:
@@ -997,6 +1017,12 @@ def get_logprobs(
                 if len(vlm_kwargs) > 0:
                     del model_args["flash_attn_kwargs"]

+                if (
+                    not self.allow_flash_attn_args
+                    and "flash_attn_kwargs" in model_args
+                ):
+                    del model_args["flash_attn_kwargs"]
+
                 outputs = self.model(**model_args)

                 logits = outputs.logits
@@ -1158,7 +1184,7 @@ def score(self, data: BatchedDataDict) -> BatchedDataDict[ScoreOutputSpec]:
         )
         self.model.eval()
         print("Begin to batch datas")
-        with unshard_fsdp2_model(self.model), torch.no_grad():
+        with torch.no_grad():
             data.to("cuda")
             dummy_iterator = iter([])
             if self.cfg["dynamic_batching"]["enabled"]:
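
The new check_model_allow_flash_attn_args gate keys off two fields of the Hugging Face model config. A minimal standalone sketch of the same test, assuming the config is loadable with transformers.AutoConfig (trust_remote_code is assumed because nemotron-nas models ship custom modeling code):

from transformers import AutoConfig

# nemotron-nas checkpoints report architecture "DeciLMForCausalLM" and
# model_type "nemotron-nas"; their forward() does not accept
# flash_attn_kwargs, so the worker drops those kwargs before calling the model.
config = AutoConfig.from_pretrained(
    "nvidia/Llama-3_3-Nemotron-Super-49B-v1_5", trust_remote_code=True
)
allow_flash_attn_args = not (
    config.architectures[0] == "DeciLMForCausalLM"
    and config.model_type == "nemotron-nas"
)
print(allow_flash_attn_args)  # expected: False for this model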
Lines changed: 41 additions & 0 deletions
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=4
STEPS_PER_RUN=2 # 40min: step_time: [1341, 801]
MAX_STEPS=2
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=30
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_grpo_math.py \
    --config $CONFIG_PATH \
    grpo.max_num_steps=$MAX_STEPS \
    logger.log_dir=$LOG_DIR \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name=$EXP_NAME \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir=$CKPT_DIR \
    $@ \
    2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
    uv run tests/check_metrics.py $JSON_METRICS \
        'mean(data["train/token_mult_prob_error"]) < 1.1' \
        'data["train/token_mult_prob_error"]["2"] < 1.1' \
        'mean(data["timing/train/policy_training"]) < 280' \
        'mean(data["ray/node.0.gpu.0.mem_gb"]) < 75'
fi
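
The jq gate above extracts the highest step logged for train/loss and only runs check_metrics.py once training has reached MAX_STEPS. A minimal Python sketch of the same gate, assuming the {metric: {step: value}} JSON layout produced by tests/json_dump_tb_logs.py (the file path here is hypothetical):

import json

MAX_STEPS = 2  # mirrors the script's config above

with open("json_metrics.json") as f:  # hypothetical path for $JSON_METRICS
    metrics = json.load(f)

# Highest step index recorded for train/loss (keys are stringified steps).
max_logged_step = max(int(step) for step in metrics["train/loss"])

if max_logged_step >= MAX_STEPS:
    errors = list(metrics["train/token_mult_prob_error"].values())
    assert sum(errors) / len(errors) < 1.1  # mean token mult-prob error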

tests/test_suites/nightly.txt

Lines changed: 3 additions & 0 deletions
@@ -40,6 +40,9 @@ tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8.sh
 # Non-colocated
 tests/test_suites/llm/grpo-llama3.1-8b-instruct-2n8g-fsdp2tp1-noncolocated.sh

+# Nemotron Super 49B
+tests/test_suites/llm/grpo-math-llama-nemotron-super-49b-v.5-4n8g-fsdp2tp8.sh
+
 #######
 # SFT #
 #######
