@@ -543,6 +543,87 @@ def test_adapter_gradient_flow(lora_config, tmp_path):
 )


+def _test_adapter_gradient_flow_freeze_lora_with_api(lora_config, tmp_path, rank, size):
+    hidden_size = 256
+
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+    lora_config["freeze_lora_weights"] = False
+    lora_config["freeze_base_model"] = False
+
+    mtpeft.update_model(model, lora_config)
+    # Freeze only the LoRA weights in the self_attention layers
+    mtpeft.freeze_lora_weights(model, layer_patterns="*self_attention*")
+    model.train()
+
+    # Use a simple forward pass for the gradient check
+    batch_size = prompt_tokens.shape[0]
+    seq_len = prompt_tokens.shape[-1]
+    device = prompt_tokens.device
+
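+    # Build a causal attention mask: True above the diagonal marks the future
+    # positions each token must not attend to, shaped [batch, 1, seq, seq].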
+    attention_mask = (
+        torch.triu(torch.ones((batch_size, seq_len, seq_len), device=device), diagonal=1)
+        .bool()
+        .view(batch_size, 1, seq_len, seq_len)
+    )
+
+    output = model(prompt_tokens, position_ids=None, attention_mask=attention_mask)
+
+    loss = output.sum()
+    loss.backward()
+
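+    # Only the LoRA weights matched by "*self_attention*" were frozen, so they
+    # should have no gradients; the base weights and the remaining LoRA
+    # adapters should all receive nonzero gradients.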
+    for name, param in model.named_parameters():
+        if "lora" in name and "self_attention" in name:
+            assert param.grad is None
+        else:
+            assert param.grad is not None
+            assert torch.any(param.grad != 0), "weight gradient is all zeros"
+
+    for p in model.parameters():
+        p.grad = None
+
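+    # With no layer_patterns filter, freeze_lora_weights should freeze every
+    # LoRA adapter weight in the model.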
+    mtpeft.freeze_lora_weights(model)
+    model.train()
+
+    # Use a simple forward pass for the gradient check
+    batch_size = prompt_tokens.shape[0]
+    seq_len = prompt_tokens.shape[-1]
+    device = prompt_tokens.device
+
+    attention_mask = (
+        torch.triu(torch.ones((batch_size, seq_len, seq_len), device=device), diagonal=1)
+        .bool()
+        .view(batch_size, 1, seq_len, seq_len)
+    )
+
+    output = model(prompt_tokens, position_ids=None, attention_mask=attention_mask)
+
+    loss = output.sum()
+    loss.backward()
+
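+    # Now no LoRA parameter should have received a gradient, while every base
+    # model weight should still get a nonzero gradient.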
+    for name, param in model.named_parameters():
+        if "lora" in name:
+            assert param.grad is None
+        else:
+            assert param.grad is not None
+            assert torch.any(param.grad != 0), "weight gradient is all zeros"
+
+
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        LARGE_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_adapter_gradient_flow_freeze_lora_with_api(lora_config, tmp_path):
+    spawn_multiprocess_job(
+        size=torch.cuda.device_count(),
+        job=partial(_test_adapter_gradient_flow_freeze_lora_with_api, lora_config, str(tmp_path)),
+        backend="nccl",
+    )
+
+
 def _test_quantize_then_lora(lora_config, tmp_path, rank, size):
     hidden_size = 512
     initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)