@@ -928,16 +928,14 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
928 928         pg_collection = ProcessGroupCollection.use_mpu_process_groups(),
929 929     )
930930
931+ # If use_peft, the pretrained checkpoint weights are already loaded inside of the pre_wrap_hook
932+ # so they only need to be loaded here if use_peft is False
931 933     should_load_checkpoint = (
932 -           ref_checkpoint_config.pretrained_checkpoint is not None
934 +           not use_peft
935 +           and ref_checkpoint_config.pretrained_checkpoint is not None
933 936         and checkpoint_exists(ref_checkpoint_config.pretrained_checkpoint)
934 937     )
935938
936 -       if should_load_checkpoint and use_peft:
937 -           # The finetune toggle is explicitly set to True in order to avoid loading optimizer and RNG states
938 -           # This is switched off here in order to load these states from the checkpoint
939 -           ref_megatron_cfg.checkpoint.finetune = False
940 -
941 939     print("Loading the Reference Model")
942940
943 941     if should_load_checkpoint:
@@ -949,8 +947,6 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
949 947             checkpointing_context = ref_ckpt_context,
950 948             skip_load_to_model_and_opt = HAVE_FSDP2 and megatron_cfg.dist.use_torch_fsdp2,
951 949         )
952 -       else:
953 -           print("Reference model not loaded")
954950
955951 reference_state_dict = {}
956952
@@ -966,6 +962,8 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
966962 cpu_item = item
967 963             reference_state_dict[name] = cpu_item
968 964         print("Reference model loaded")
965 +       else:
966 +           print("Reference model not loaded")
969967
970968 return reference_state_dict
971969
0 commit comments