
Commit a426896

Authored by vadam5, ashors1, root, and terrykong
feat: Megatron LoRA GRPO w/ Weight Merging (#1889)
Signed-off-by: Anna Shors <ashors@nvidia.com>
Signed-off-by: Virginia Wu <vadams@nvidia.com>
Signed-off-by: Virginia Wu <78445382+vadam5@users.noreply.github.com>
Signed-off-by: Terry Kong <terryk@nvidia.com>
Co-authored-by: Anna Shors <ashors@nvidia.com>
Co-authored-by: root <root@pool0-00689.cm.cluster>
Co-authored-by: Terry Kong <terryk@nvidia.com>
1 parent: d62702f · commit: a426896

16 files changed: 427 additions & 23 deletions

examples/configs/distillation_math_megatron.yaml

Lines changed: 13 additions & 1 deletion

@@ -62,7 +62,19 @@ policy: &POLICY_BASE
     moe_enable_deepep: false
     moe_token_dispatcher_type: "allgather"
     moe_shared_expert_overlap: false
-
+    peft:
+      enabled: false
+      target_modules: []
+      exclude_modules: []
+      dim: 8
+      alpha: 32
+      dropout: 0.0
+      dropout_position: "post"
+      lora_A_init_method: "xavier"
+      lora_B_init_method: "zero"
+      a2a_experimental: false
+      lora_dtype: null
+
   optimizer:
     optimizer: "adam"
     lr: 2.00001e-5
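
The peft block above is the full set of LoRA knobs this PR threads through the Megatron configs. As a minimal sketch of how dim, alpha, and the two init methods combine, assuming the standard LoRA formulation these keys mirror (shapes are illustrative, nothing here is taken from the repo):

import torch

# Standard LoRA update (an assumption about these knobs, not code from this
# PR): the frozen weight W gains a low-rank delta scaled by alpha / dim.
d_out, d_in, dim, alpha = 2048, 2048, 8, 32

W = torch.randn(d_out, d_in)           # frozen base weight
lora_A = torch.empty(dim, d_in)
torch.nn.init.xavier_normal_(lora_A)   # lora_A_init_method: "xavier"
lora_B = torch.zeros(d_out, dim)       # lora_B_init_method: "zero"

def forward(x: torch.Tensor) -> torch.Tensor:
    # y = x W^T + (alpha / dim) * x A^T B^T
    return x @ W.T + (alpha / dim) * (x @ lora_A.T) @ lora_B.T

x = torch.randn(4, d_in)
assert torch.allclose(forward(x), x @ W.T)  # zero-init B: a no-op at step 0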

examples/configs/grpo_math_1B.yaml

Lines changed: 13 additions & 0 deletions

@@ -147,6 +147,19 @@ policy:
     moe_token_dispatcher_type: "allgather"
     moe_shared_expert_overlap: false

+    peft:
+      enabled: false
+      target_modules: []
+      exclude_modules: []
+      dim: 8
+      alpha: 32
+      dropout: 0.0
+      dropout_position: "post"
+      lora_A_init_method: "xavier"
+      lora_B_init_method: "zero"
+      a2a_experimental: false
+      lora_dtype: null
+
   optimizer:
     optimizer: "adam"
     lr: 5.0e-6
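
One spelling subtlety in this block: YAML only recognizes null (or ~, or an empty value) as a true null, while a bare None parses as the string "None". A quick PyYAML sketch of the difference, which is why lora_dtype uses null here and in the sibling configs:

import yaml

# YAML has no bare `None`: it parses as the plain string "None", while
# `null` (or `~`, or an empty value) parses to Python's None.
print(yaml.safe_load("lora_dtype: None"))  # {'lora_dtype': 'None'}
print(yaml.safe_load("lora_dtype: null"))  # {'lora_dtype': None}
print(yaml.safe_load("lora_dtype:"))       # {'lora_dtype': None}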

examples/configs/grpo_math_1B_megatron.yaml

Lines changed: 13 additions & 0 deletions

@@ -100,6 +100,19 @@ policy:
     moe_shared_expert_overlap: false
     # gives ~20% training perf speedup with sequence packing
     apply_rope_fusion: True
+
+    peft:
+      enabled: false
+      target_modules: []
+      exclude_modules: []
+      dim: 8
+      alpha: 32
+      dropout: 0.0
+      dropout_position: "post"
+      lora_A_init_method: "xavier"
+      lora_B_init_method: "zero"
+      a2a_experimental: false
+      lora_dtype: null

   optimizer:
     optimizer: "adam"
Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+defaults: ../../grpo_math_1B.yaml
+grpo:
+  num_prompts_per_step: 2
+  num_generations_per_prompt: 8
+checkpointing:
+  checkpoint_dir: results/grpo-nanov3-30BA3B-2n8g-megatron-lora
+policy:
+  model_name: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-Base-BF16
+  tokenizer:
+    name: nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16
+  train_global_batch_size: 16
+  train_micro_batch_size: 1
+  logprob_batch_size: 1
+  max_total_sequence_length: 2048
+  dtensor_cfg:
+    enabled: false
+  megatron_cfg:
+    enabled: true
+    peft:
+      enabled: true
+      dim: 128
+      alpha: 512
+      exclude_modules: ['*out_proj*'] # Exclude all out_proj modules. When NemotronHMamba2Mixer uses cuda_kernels_forward, out_proj LoRA has no gradient.
+  sequence_packing:
+    enabled: false
+  generation:
+    vllm_cfg:
+      tensor_parallel_size: 4
+      gpu_memory_utilization: 0.7
+logger:
+  wandb_enabled: true
+  tensorboard_enabled: true
+  wandb:
+    project: nemo-rl
+    name: grpo-nanov3-30BA3B-2n8g-megatron-lora
+cluster:
+  gpus_per_node: 8
+  num_nodes: 2
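
The exclude_modules entry is a wildcard pattern, per its inline comment. A sketch of glob-style filtering over module names, assuming fnmatch-like semantics (the exact matching rule Megatron's LoRA applies is an assumption, not shown in this diff):

from fnmatch import fnmatch

# Wildcard-based module filtering, assuming fnmatch-style semantics for
# exclude_modules; module names below are illustrative.
exclude_modules = ["*out_proj*"]

names = [
    "decoder.layers.0.mixer.in_proj",
    "decoder.layers.0.mixer.out_proj",  # skipped: no gradient under
    "decoder.layers.0.mlp.linear_fc1",  # cuda_kernels_forward
]
adapted = [n for n in names if not any(fnmatch(n, p) for p in exclude_modules)]
print(adapted)  # out_proj is excluded from adaptation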
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+defaults: ../../grpo_math_1B.yaml
+grpo:
+  val_at_start: true
+checkpointing:
+  enabled: false
+  checkpoint_dir: results/grpo-qwen3-8b-base-1n8g-megatron-lora
+policy:
+  model_name: Qwen/Qwen3-8B-Base
+  max_total_sequence_length: 2048
+  dtensor_cfg:
+    enabled: false
+  megatron_cfg:
+    enabled: true
+    peft:
+      enabled: true
+      dim: 128
+      alpha: 128
+    scheduler:
+      lr_warmup_iters: 50
+
+  sequence_packing:
+    enabled: false
+logger:
+  log_dir: logs/grpo-qwen3-8b-base-1n8g-megatron-lora
+  wandb_enabled: true
+  tensorboard_enabled: true
+  wandb:
+    project: nemo-rl
+    name: grpo-qwen3-8b-base-1n8g-megatron-lora
+cluster:
+  gpus_per_node: 8
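
Both new recipes start from defaults: ../../grpo_math_1B.yaml and override only a handful of keys. A sketch of the recursive overlay this implies; the real loader's merge semantics are an assumption here:

# Recipe keys recursively override the base config while untouched keys
# (optimizer, dropout, ...) fall through. Assumed semantics, for
# illustration only.
def deep_merge(base: dict, override: dict) -> dict:
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

base = {"policy": {"megatron_cfg": {"enabled": False, "peft": {"enabled": False, "dim": 8}}}}
recipe = {"policy": {"megatron_cfg": {"enabled": True, "peft": {"enabled": True, "dim": 128}}}}
print(deep_merge(base, recipe))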

examples/configs/sft_openmathinstruct2_megatron.yaml

Lines changed: 12 additions & 0 deletions

@@ -96,6 +96,18 @@ policy:
     moe_enable_deepep: false
     moe_token_dispatcher_type: "allgather"
     moe_shared_expert_overlap: false
+    peft:
+      enabled: false
+      target_modules: []
+      exclude_modules: []
+      dim: 8
+      alpha: 32
+      dropout: 0.0
+      dropout_position: "post"
+      lora_A_init_method: "xavier"
+      lora_B_init_method: "zero"
+      a2a_experimental: false
+      lora_dtype: null

   env_vars:
     PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"
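
Worth noting why the defaults pair lora_A_init_method: "xavier" with lora_B_init_method: "zero": the product B @ A then starts at zero, so enabling the adapter leaves the base model's outputs untouched, while gradients through B are still nonzero. A small self-contained check:

import torch

# xavier A + zero B: the adapter output is zero at initialization, yet the
# gradient w.r.t. B involves A (which is not zero), so training can move.
dim, d_in, d_out = 8, 64, 64
A = torch.empty(dim, d_in)
torch.nn.init.xavier_normal_(A)
B = torch.zeros(d_out, dim, requires_grad=True)

x = torch.randn(2, d_in)
y = (x @ A.T) @ B.T
y.sum().backward()
assert torch.count_nonzero(y) == 0        # no-op forward at step 0
assert torch.count_nonzero(B.grad) > 0    # but learning can proceed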

examples/configs/vlm_grpo_3B_megatron.yaml

Lines changed: 12 additions & 0 deletions

@@ -158,6 +158,18 @@ policy:
     moe_enable_deepep: false
     moe_token_dispatcher_type: "allgather"
     moe_shared_expert_overlap: false
+    peft:
+      enabled: false
+      target_modules: []
+      exclude_modules: []
+      dim: 8
+      alpha: 32
+      dropout: 0.0
+      dropout_position: "post"
+      lora_A_init_method: "xavier"
+      lora_B_init_method: "zero"
+      a2a_experimental: false
+      lora_dtype: null
   optimizer:
     optimizer: adam
     lr: 2.0e-07

nemo_rl/models/megatron/setup.py

Lines changed: 68 additions & 12 deletions

@@ -684,10 +684,6 @@ def setup_model_and_optimizer(

     mixed_precision_wrapper = Float16Module
     if policy_cfg["megatron_cfg"]["freeze_moe_router"]:
-        if use_peft:
-            raise ValueError(
-                "Freezing the MOE router is not currently supported when using PEFT"
-            )

         def freeze_moe_router(megatron_model):
             if not isinstance(megatron_model, list):
@@ -708,6 +704,14 @@ def freeze_moe_router(megatron_model):

     if use_peft:
         peft_cfg = policy_cfg["megatron_cfg"].get("peft", {})
+        if "dim" not in peft_cfg or peft_cfg["dim"] is None:
+            raise ValueError(
+                "If megatron_cfg.peft.enabled is True, dim must be set in peft_cfg"
+            )
+        if "alpha" not in peft_cfg or peft_cfg["alpha"] is None:
+            raise ValueError(
+                "If megatron_cfg.peft.enabled is True, alpha must be set in peft_cfg"
+            )
         peft = LoRA(
             target_modules=peft_cfg["target_modules"],
             exclude_modules=peft_cfg["exclude_modules"],
@@ -722,6 +726,7 @@ def freeze_moe_router(megatron_model):
         )
     else:
         peft = None
+
     megatron_cfg.peft = peft

     if megatron_cfg.peft is not None:
@@ -872,22 +877,70 @@ def setup_reference_model_state(
     if config["megatron_cfg"].get("freeze_moe_router", False):
         ref_mixed_precision_wrapper = MoEFloat16Module

+    ref_pre_wrap_hooks = []
+    use_peft = config["megatron_cfg"].get("peft", {}).get("enabled", False)
+
+    if use_peft:
+        peft_cfg = config["megatron_cfg"].get("peft", {})
+        if "dim" not in peft_cfg or peft_cfg["dim"] is None:
+            raise ValueError(
+                "If megatron_cfg.peft.enabled is True, dim must be set in peft_cfg"
+            )
+        if "alpha" not in peft_cfg or peft_cfg["alpha"] is None:
+            raise ValueError(
+                "If megatron_cfg.peft.enabled is True, alpha must be set in peft_cfg"
+            )
+        peft = LoRA(
+            target_modules=peft_cfg["target_modules"],
+            exclude_modules=peft_cfg["exclude_modules"],
+            dim=peft_cfg["dim"],
+            alpha=peft_cfg["alpha"],
+            dropout=peft_cfg["dropout"],
+            dropout_position=peft_cfg["dropout_position"],
+            lora_A_init_method="zero",
+            lora_B_init_method="zero",
+            a2a_experimental=peft_cfg["a2a_experimental"],
+            lora_dtype=peft_cfg["lora_dtype"],
+        )
+    else:
+        peft = None
+
+    ref_megatron_cfg.peft = peft
+
+    if ref_megatron_cfg.peft is not None:
+        pre_peft_hook = _create_peft_pre_wrap_hook(ref_megatron_cfg, ref_state)
+        ref_megatron_cfg.model.register_pre_wrap_hook(pre_peft_hook)
+
+        def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
+            model = pre_peft_hook(model)
+            return model
+
+        ref_pre_wrap_hooks.extend([composed_peft_hook])
+
     reference_model = get_model(
         megatron_cfg.model,
         megatron_cfg.ddp,
         use_torch_fsdp2=megatron_cfg.dist.use_torch_fsdp2,
         overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step,
-        pre_wrap_hook=megatron_cfg.rng.data_parallel_random_init,
+        data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init,
+        pre_wrap_hook=ref_pre_wrap_hooks,
         mixed_precision_wrapper=ref_mixed_precision_wrapper,
         pg_collection=ProcessGroupCollection.use_mpu_process_groups(),
     )

+    should_load_checkpoint = (
+        ref_checkpoint_config.pretrained_checkpoint is not None
+        and checkpoint_exists(ref_checkpoint_config.pretrained_checkpoint)
+    )
+
+    if should_load_checkpoint and use_peft:
+        # `finetune` is normally set to True to avoid loading optimizer and RNG
+        # states; it is switched off here so those states are loaded from the checkpoint.
+        ref_megatron_cfg.checkpoint.finetune = False
+
     print("Loading the Reference Model")
-    reference_state_dict = {}

-    if ref_checkpoint_config.pretrained_checkpoint is not None and checkpoint_exists(
-        ref_checkpoint_config.pretrained_checkpoint
-    ):
+    if should_load_checkpoint:
         load_checkpoint(
             ref_state,
             reference_model,
@@ -896,9 +949,14 @@ def setup_reference_model_state(
             checkpointing_context=ref_ckpt_context,
             skip_load_to_model_and_opt=HAVE_FSDP2 and megatron_cfg.dist.use_torch_fsdp2,
         )
+    else:
+        print("Reference model not loaded")
+
+    reference_state_dict = {}
+
+    if should_load_checkpoint or use_peft:
         reference_model = reference_model[0]
         reference_model.eval()
-
         # Store reference state dict on CPU
         for name, item in reference_model.state_dict().items():
             if isinstance(item, torch.Tensor):
@@ -908,8 +966,6 @@ def setup_reference_model_state(
                 cpu_item = item
             reference_state_dict[name] = cpu_item
         print("Reference model loaded")
-    else:
-        print("Reference model not loaded")

     return reference_state_dict
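
Note that setup_reference_model_state builds the reference LoRA with both init methods forced to "zero", overriding the configured xavier/zero pair. With both factors zero the adapter delta is identically zero, so the PEFT-wrapped reference computes exactly the frozen base model while its state_dict keys still mirror the adapted policy's. A sketch with illustrative shapes:

import torch

# Both LoRA factors zero-initialized, as the reference model above forces:
# the delta (alpha / dim) * B @ A is identically zero, so the wrapped
# reference behaves exactly like the frozen base weight W.
dim, d_in, d_out, alpha = 8, 64, 64, 32
W = torch.randn(d_out, d_in)
A = torch.zeros(dim, d_in)
B = torch.zeros(d_out, dim)

x = torch.randn(2, d_in)
assert torch.allclose(x @ W.T + (alpha / dim) * (x @ A.T) @ B.T, x @ W.T)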

nemo_rl/models/policy/__init__.py

Lines changed: 19 additions & 0 deletions

@@ -108,6 +108,24 @@ class RewardModelConfig(TypedDict):
     reward_model_type: str


+class MegatronPeftConfigDisabled(TypedDict):
+    enabled: Literal[False]
+
+
+class MegatronPeftConfig(TypedDict):
+    enabled: Literal[True]
+    target_modules: list[str]
+    exclude_modules: list[str]
+    dim: int
+    alpha: int
+    dropout: float
+    dropout_position: Literal["pre", "post"]
+    lora_A_init_method: str
+    lora_B_init_method: str
+    a2a_experimental: bool
+    lora_dtype: str | None
+
+
 class MegatronOptimizerConfig(TypedDict):
     optimizer: str
     lr: float
@@ -193,6 +211,7 @@ class MegatronConfig(TypedDict):
     moe_token_dispatcher_type: str
     # Can be used only with 'alltoall' token dispatcher
     moe_shared_expert_overlap: bool
+    peft: NotRequired[MegatronPeftConfig | MegatronPeftConfigDisabled]
     optimizer: MegatronOptimizerConfig
     scheduler: MegatronSchedulerConfig
     distributed_data_parallel_config: MegatronDDPConfig
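
Because enabled is typed as Literal[True] / Literal[False], the union is discriminated and a type checker can narrow it at a branch. A sketch of a downstream consumer (the describe function and the abridged field lists are illustrative, not from this PR):

from typing import Literal, TypedDict

class MegatronPeftConfigDisabled(TypedDict):
    enabled: Literal[False]

class MegatronPeftConfig(TypedDict):  # abridged copy of the new TypedDict
    enabled: Literal[True]
    dim: int
    alpha: int

# The `enabled` literal acts as a tag: on the True branch a type checker
# narrows the union, so `dim` and `alpha` are only read from the enabled
# variant.
def describe(peft: MegatronPeftConfig | MegatronPeftConfigDisabled) -> str:
    if peft["enabled"]:
        return f"LoRA rank {peft['dim']}, alpha {peft['alpha']}"
    return "PEFT disabled"

print(describe({"enabled": True, "dim": 128, "alpha": 128}))
print(describe({"enabled": False}))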

nemo_rl/models/policy/workers/megatron_policy_worker.py

Lines changed: 7 additions & 9 deletions

@@ -541,10 +541,10 @@ def use_reference_model(self):
                )
                model_state_dict[name] = item

-            self.model.load_state_dict(self.reference_state_dict, strict=True)
-            # for name, item in self.reference_state_dict.items():
-            #     if isinstance(item, torch.Tensor):
-            #         self.model.state_dict()[name] = item.detach().to(device="cuda", non_blocking=True, copy=True)
+            # Swap the reference model weights into self.model in place
+            for k, v in self.model.state_dict().items():
+                if isinstance(v, torch.Tensor):
+                    v.copy_(self.reference_state_dict[k])

            if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1:
                gc.collect()
@@ -556,11 +556,9 @@ def use_reference_model(self):

        finally:
            # Restore original references and device placement
-            self.model.load_state_dict(model_state_dict, strict=True)
-            # for name, item in model_state_dict.items():
-            #     if isinstance(item, torch.Tensor):
-            #         item = item.detach().to(device="cuda", non_blocking=True, copy=True)
-            #         self.model.state_dict()[name] = item
+            for k, v in self.model.state_dict().items():
+                if isinstance(v, torch.Tensor):
+                    v.copy_(model_state_dict[k])

            if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1:
                gc.collect()
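
The swap uses in-place Tensor.copy_ rather than load_state_dict: copy_ writes through the existing storage, so anything already holding the parameter (optimizer state, DDP buckets) keeps seeing the same tensor object. A minimal sketch of the pattern, mirroring the swap-then-restore shape above:

import torch

# In-place swap: copy_ overwrites the existing storage, so live references
# to the parameter remain valid across the swap and restore.
model = torch.nn.Linear(4, 4)
snapshot = {k: v.detach().clone() for k, v in model.state_dict().items()}

before = model.weight                    # keep a live reference
with torch.no_grad():
    for k, v in model.state_dict().items():
        v.copy_(torch.zeros_like(v))     # swap in other weights...
    for k, v in model.state_dict().items():
        v.copy_(snapshot[k])             # ...then restore, as in `finally`

assert before is model.weight            # same tensor object throughout
assert torch.equal(model.weight, snapshot["weight"])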
