Skip to content

Commit 8c70ee6

Browse files
authored
fix: checkpoint loading bug in Megatron LoRA GRPO (#2075)
Signed-off-by: Virginia Wu <vadams@nvidia.com>
1 parent 919e373 commit 8c70ee6

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

nemo_rl/models/megatron/setup.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -928,16 +928,14 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
928928
pg_collection=ProcessGroupCollection.use_mpu_process_groups(),
929929
)
930930

931+
# If use_peft, the pretrained checkpoint weights are already loaded inside of the pre_wrap_hook
932+
# so they only need to be loaded here if use_peft is False
931933
should_load_checkpoint = (
932-
ref_checkpoint_config.pretrained_checkpoint is not None
934+
not use_peft
935+
and ref_checkpoint_config.pretrained_checkpoint is not None
933936
and checkpoint_exists(ref_checkpoint_config.pretrained_checkpoint)
934937
)
935938

936-
if should_load_checkpoint and use_peft:
937-
# The finetune toggle is explicitly set to True in order to avoid loading optimizer and RNG states
938-
# This is switched off here in order to load these states from the checkpoint
939-
ref_megatron_cfg.checkpoint.finetune = False
940-
941939
print("Loading the Reference Model")
942940

943941
if should_load_checkpoint:
@@ -949,8 +947,6 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
949947
checkpointing_context=ref_ckpt_context,
950948
skip_load_to_model_and_opt=HAVE_FSDP2 and megatron_cfg.dist.use_torch_fsdp2,
951949
)
952-
else:
953-
print("Reference model not loaded")
954950

955951
reference_state_dict = {}
956952

@@ -966,6 +962,8 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
966962
cpu_item = item
967963
reference_state_dict[name] = cpu_item
968964
print("Reference model loaded")
965+
else:
966+
print("Reference model not loaded")
969967

970968
return reference_state_dict
971969

tests/functional/L1_Functional_Tests_GPU.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async
5252
run_test uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh
5353
run_test uv run --no-sync bash ./tests/functional/grpo_megatron.sh
5454
run_test uv run --no-sync bash ./tests/functional/grpo_megatron_generation.sh
55-
run_test uv run --no-sync bash ./tests/functional/grpo_megatron_lora.sh
56-
run_test uv run --no-sync bash ./tests/functional/grpo_megatron_lora_async.sh
55+
run_test fast uv run --no-sync bash ./tests/functional/grpo_megatron_lora.sh
56+
run_test fast uv run --no-sync bash ./tests/functional/grpo_megatron_lora_async.sh
5757
run_test uv run --no-sync bash ./tests/functional/grpo_multiple_dataloaders.sh
5858
run_test uv run --no-sync bash ./tests/functional/grpo_multiturn.sh
5959
run_test uv run --no-sync bash ./tests/functional/grpo_non_colocated.sh

0 commit comments

Comments (0)