@@ -684,10 +684,6 @@ def setup_model_and_optimizer(
 
     mixed_precision_wrapper = Float16Module
     if policy_cfg["megatron_cfg"]["freeze_moe_router"]:
-        if use_peft:
-            raise ValueError(
-                "Freezing the MOE router is not currently supported when using PEFT"
-            )
 
         def freeze_moe_router(megatron_model):
             if not isinstance(megatron_model, list):
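The removed guard means the router-freezing path can now run on a PEFT-wrapped model. For readers without the rest of the file in view, here is a minimal hedged sketch of what such a freezing hook typically does; the `"router"` substring match is an assumption about Megatron's MoE parameter naming, not taken from this diff:

import torch

def freeze_moe_router_sketch(megatron_model):
    # Hypothetical sketch, not the repository's implementation: freeze any
    # parameter whose name suggests it belongs to an MoE router.
    if not isinstance(megatron_model, list):
        megatron_model = [megatron_model]
    for model_chunk in megatron_model:
        for name, param in model_chunk.named_parameters():
            if "router" in name:  # assumed naming convention
                param.requires_grad = False
    return megatron_model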
@@ -708,6 +704,14 @@ def freeze_moe_router(megatron_model):
 
     if use_peft:
         peft_cfg = policy_cfg["megatron_cfg"].get("peft", {})
+        if "dim" not in peft_cfg or peft_cfg["dim"] is None:
+            raise ValueError(
+                "If megatron_cfg.peft.enabled is True, dim must be set in peft_cfg"
+            )
+        if "alpha" not in peft_cfg or peft_cfg["alpha"] is None:
+            raise ValueError(
+                "If megatron_cfg.peft.enabled is True, alpha must be set in peft_cfg"
+            )
         peft = LoRA(
             target_modules=peft_cfg["target_modules"],
             exclude_modules=peft_cfg["exclude_modules"],
@@ -722,6 +726,7 @@ def freeze_moe_router(megatron_model):
         )
     else:
         peft = None
+
     megatron_cfg.peft = peft
 
     if megatron_cfg.peft is not None:
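The two new checks follow an identical pattern, and the same pattern is duplicated for the reference model further down. If more required keys accumulate, they could be folded into one helper. A hedged sketch, equivalent to the added checks (the helper name is illustrative, not part of this commit):

def _require_peft_keys(peft_cfg: dict, keys=("dim", "alpha")) -> None:
    # Illustrative only: peft_cfg.get(key) is None covers both the
    # missing-key and explicit-None cases, matching the inline checks.
    for key in keys:
        if peft_cfg.get(key) is None:
            raise ValueError(
                f"If megatron_cfg.peft.enabled is True, {key} must be set in peft_cfg"
            )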
@@ -872,22 +877,70 @@ def setup_reference_model_state(
     if config["megatron_cfg"].get("freeze_moe_router", False):
         ref_mixed_precision_wrapper = MoEFloat16Module
 
+    ref_pre_wrap_hooks = []
+    use_peft = config["megatron_cfg"].get("peft", {}).get("enabled", False)
+
+    if use_peft:
+        peft_cfg = config["megatron_cfg"].get("peft", {})
+        if "dim" not in peft_cfg or peft_cfg["dim"] is None:
+            raise ValueError(
+                "If megatron_cfg.peft.enabled is True, dim must be set in peft_cfg"
+            )
+        if "alpha" not in peft_cfg or peft_cfg["alpha"] is None:
+            raise ValueError(
+                "If megatron_cfg.peft.enabled is True, alpha must be set in peft_cfg"
+            )
+        peft = LoRA(
+            target_modules=peft_cfg["target_modules"],
+            exclude_modules=peft_cfg["exclude_modules"],
+            dim=peft_cfg["dim"],
+            alpha=peft_cfg["alpha"],
+            dropout=peft_cfg["dropout"],
+            dropout_position=peft_cfg["dropout_position"],
+            lora_A_init_method="zero",
+            lora_B_init_method="zero",
+            a2a_experimental=peft_cfg["a2a_experimental"],
+            lora_dtype=peft_cfg["lora_dtype"],
+        )
+    else:
+        peft = None
+
+    ref_megatron_cfg.peft = peft
+
+    if ref_megatron_cfg.peft is not None:
+        pre_peft_hook = _create_peft_pre_wrap_hook(ref_megatron_cfg, ref_state)
+        ref_megatron_cfg.model.register_pre_wrap_hook(pre_peft_hook)
+
+        def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
+            model = pre_peft_hook(model)
+            return model
+
+        ref_pre_wrap_hooks.extend([composed_peft_hook])
+
     reference_model = get_model(
         megatron_cfg.model,
         megatron_cfg.ddp,
         use_torch_fsdp2=megatron_cfg.dist.use_torch_fsdp2,
         overlap_param_gather_with_optimizer_step=megatron_cfg.optimizer.overlap_param_gather_with_optimizer_step,
-        pre_wrap_hook=megatron_cfg.rng.data_parallel_random_init,
+        data_parallel_random_init=megatron_cfg.rng.data_parallel_random_init,
+        pre_wrap_hook=ref_pre_wrap_hooks,
         mixed_precision_wrapper=ref_mixed_precision_wrapper,
         pg_collection=ProcessGroupCollection.use_mpu_process_groups(),
     )
 
+    should_load_checkpoint = (
+        ref_checkpoint_config.pretrained_checkpoint is not None
+        and checkpoint_exists(ref_checkpoint_config.pretrained_checkpoint)
+    )
+
+    if should_load_checkpoint and use_peft:
+        # The finetune toggle is explicitly set to True elsewhere to avoid loading optimizer and RNG states;
+        # it is switched off here so that those states are loaded from the checkpoint.
+        ref_megatron_cfg.checkpoint.finetune = False
+
     print("Loading the Reference Model")
-    reference_state_dict = {}
 
-    if ref_checkpoint_config.pretrained_checkpoint is not None and checkpoint_exists(
-        ref_checkpoint_config.pretrained_checkpoint
-    ):
+    if should_load_checkpoint:
         load_checkpoint(
             ref_state,
             reference_model,
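Note that both LoRA init methods are set to "zero" for the reference model, so the adapter contributes exactly nothing and the wrapped reference model stays numerically identical to the base weights. A self-contained check of that arithmetic in plain torch, assuming the standard LoRA update W x + (alpha / r) * B A x; nothing below is Megatron-specific:

import torch

# With A and B both zero, the LoRA delta (alpha / r) * B @ A @ x vanishes.
x = torch.randn(4, 16)
W = torch.randn(16, 16)
A = torch.zeros(8, 16)   # lora_A, zero-initialized
B = torch.zeros(16, 8)   # lora_B, zero-initialized
alpha, r = 32.0, 8

base_out = x @ W.T
lora_out = x @ W.T + (alpha / r) * (x @ A.T) @ B.T
assert torch.equal(base_out, lora_out)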
@@ -896,9 +949,14 @@ def setup_reference_model_state(
             checkpointing_context=ref_ckpt_context,
             skip_load_to_model_and_opt=HAVE_FSDP2 and megatron_cfg.dist.use_torch_fsdp2,
         )
+    else:
+        print("Reference model not loaded")
+
+    reference_state_dict = {}
+
+    if should_load_checkpoint or use_peft:
         reference_model = reference_model[0]
         reference_model.eval()
-
         # Store reference state dict on CPU
         for name, item in reference_model.state_dict().items():
             if isinstance(item, torch.Tensor):
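The offload loop continues in the next hunk, but the tensor branch (old lines 905-907) is elided between the two. A hedged reconstruction of the whole loop as a standalone function, so the new should_load_checkpoint-or-use_peft gate can be read in one piece; the .detach().cpu() body is an assumption consistent with the surrounding context, not shown in the diff:

import torch

def state_dict_to_cpu(model: torch.nn.Module) -> dict:
    # Hedged reconstruction of the loop in this hunk; the branch moving
    # tensors to CPU is assumed, since the diff does not show those lines.
    out = {}
    for name, item in model.state_dict().items():
        if isinstance(item, torch.Tensor):
            out[name] = item.detach().cpu()  # assumed offload
        else:
            out[name] = item
    return out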
@@ -908,8 +966,6 @@ def setup_reference_model_state(
                 cpu_item = item
             reference_state_dict[name] = cpu_item
         print("Reference model loaded")
-    else:
-        print("Reference model not loaded")
 
     return reference_state_dict
 
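One closing design note: get_model now receives pre_wrap_hook=ref_pre_wrap_hooks, a list, and composed_peft_hook currently wraps just one hook. If more hooks are added later, the list can be collapsed into a single callable with a small composer. A hedged sketch of that pattern (names and the single-callable variant are illustrative, not part of this commit):

from typing import Callable

# Illustrative only: fold a list of pre-wrap hooks, each mapping a list of
# model chunks to a list of model chunks, into one hook of the same shape.
Hook = Callable[[list], list]

def compose_hooks(hooks: list[Hook]) -> Hook:
    def composed(model_chunks: list) -> list:
        for hook in hooks:
            model_chunks = hook(model_chunks)
        return model_chunks
    return composed

# compose_hooks(ref_pre_wrap_hooks) behaves like applying each hook in order.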