@@ -1,7 +1,9 @@
+import copy
 from functools import partial

 import pytest
 import torch
+import torch.nn.init as init
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import get_device_counts, spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -51,8 +53,6 @@
         "*": {
             "rank": 32,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "zero_init",
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -66,8 +66,6 @@
         "*": {
             "rank": 128,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "zero_init",
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -81,8 +79,8 @@
         "*": {
             "rank": 32,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "kaiming_init",
+            "lora_a_init": init.kaiming_uniform_,
+            "lora_b_init": init.kaiming_uniform_,
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -96,8 +94,8 @@
         "*": {
             "rank": 128,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "kaiming_init",
+            "lora_a_init": init.kaiming_uniform_,
+            "lora_b_init": init.kaiming_uniform_,
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -111,8 +109,8 @@
         "*": {
             "rank": 8,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "kaiming_init",
+            "lora_a_init": init.kaiming_uniform_,
+            "lora_b_init": init.kaiming_uniform_,
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -127,8 +125,6 @@
         "*self_attention*": {
             "rank": 16,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "zero_init",
             "enable": True,
         },
         "*output_layer*": {"enable": False},
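The config hunks above either drop the string initializer names ("kaiming_init", "zero_init") or replace them with callables from torch.nn.init. Below is a minimal sketch, assuming such a callable is simply applied in place to each freshly created LoRA weight; make_lora_pair and its shapes are hypothetical and not part of modelopt's API.

```python
import torch.nn as nn
import torch.nn.init as init


def make_lora_pair(in_features, out_features, rank,
                   a_init=init.kaiming_uniform_, b_init=init.zeros_):
    """Hypothetical helper: build a LoRA A/B pair and apply callable initializers."""
    lora_a = nn.Linear(in_features, rank, bias=False)   # down projection (rank x in)
    lora_b = nn.Linear(rank, out_features, bias=False)  # up projection (out x rank)
    # The config now passes initializers as in-place callables
    # (e.g. init.kaiming_uniform_) instead of string names.
    a_init(lora_a.weight)
    b_init(lora_b.weight)
    return lora_a, lora_b


lora_a, lora_b = make_lora_pair(1024, 1024, rank=32)
```

Passing the callable directly keeps the config self-documenting and lets a test swap in any torch.nn.init function without maintaining a registry of string names.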
@@ -449,14 +445,15 @@ def test_adapter_gradient_flow_freeze_base_model(device_count, lora_config, tmp_

 def _test_adapter_gradient_flow_freeze_lora_model(lora_config, tmp_path, rank, size):
     hidden_size = 512
-    lora_config["freeze_lora_weights"] = True
-    lora_config["freeze_base_model"] = False
+    local_cfg = copy.deepcopy(lora_config)
+    local_cfg["freeze_lora_weights"] = True
+    local_cfg["freeze_base_model"] = False

     initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
     model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

-    mtpf.update_model(model, lora_config)
+    mtpf.update_model(model, local_cfg)
     model.train()

     # Use a simple forward pass instead for grad check
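The switch to copy.deepcopy matters because the parametrized lora_config dicts are module-level constants shared across test cases; mutating one in place would leak freeze_lora_weights / freeze_base_model into every later case that reuses the same dict. A minimal illustration (SHARED_CFG is a hypothetical stand-in, not from the test file):

```python
import copy

SHARED_CFG = {"freeze_lora_weights": False}  # stand-in for a shared config constant


def run_case(cfg):
    local_cfg = copy.deepcopy(cfg)           # work on a private copy
    local_cfg["freeze_lora_weights"] = True
    return local_cfg


run_case(SHARED_CFG)
assert SHARED_CFG["freeze_lora_weights"] is False  # shared constant stays untouched
```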
@@ -569,7 +566,7 @@ def forward_func(mod):
             assert hasattr(module.weight_quantizer, "amax")
             assert getattr(module.input_quantizer, "amax") is not None
             assert getattr(module.weight_quantizer, "amax") is not None
-            # Check if the lora have teh quantizer, they should not have them.
+            # Check that the LoRA adapters do not have quantizers attached.
             for adapter_name in module._lora_adapters:
                 lora_a = module._lora_adapters[adapter_name]["lora_a"]
                 lora_b = module._lora_adapters[adapter_name]["lora_b"]
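The hunk is truncated right after the adapter lookup. As a hedged sketch of the check the comment describes (quantizers should exist only on the quantized base layer, never on the LoRA adapters), the assertions could look like the helper below; the exact assertions in the real test may differ.

```python
def assert_lora_adapters_unquantized(module):
    """Hypothetical helper: the quantized base layer's LoRA adapters carry no quantizers."""
    for adapter_name in module._lora_adapters:
        lora_a = module._lora_adapters[adapter_name]["lora_a"]
        lora_b = module._lora_adapters[adapter_name]["lora_b"]
        for lora_layer in (lora_a, lora_b):
            assert not hasattr(lora_layer, "input_quantizer")
            assert not hasattr(lora_layer, "weight_quantizer")
```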
@@ -621,7 +618,7 @@ def forward_func(mod):
             assert hasattr(module.weight_quantizer, "amax")
             assert getattr(module.input_quantizer, "amax") is not None
             assert getattr(module.weight_quantizer, "amax") is not None
-            # Check if the lora have teh quantizer, they should not have them.
+            # Check that the LoRA adapters do not have quantizers attached.
             for adapter_name in module._lora_adapters:
                 lora_a = module._lora_adapters[adapter_name]["lora_a"]
                 lora_b = module._lora_adapters[adapter_name]["lora_b"]
@@ -701,7 +698,7 @@ def forward_func(mod):
             assert hasattr(module.weight_quantizer, "amax")
             assert getattr(module.input_quantizer, "amax") is not None
             assert getattr(module.weight_quantizer, "amax") is not None
-            # Check if the lora have teh quantizer, they should not have them.
+            # Check that the LoRA adapters do not have quantizers attached.
             for adapter_name in module._lora_adapters:
                 lora_a = module._lora_adapters[adapter_name]["lora_a"]
                 lora_b = module._lora_adapters[adapter_name]["lora_b"]
@@ -765,7 +762,7 @@ def forward_func(mod):
             assert hasattr(module.weight_quantizer, "amax")
             assert getattr(module.input_quantizer, "amax") is not None
             assert getattr(module.weight_quantizer, "amax") is not None
-            # Check if the lora have teh quantizer, they should not have them.
+            # Check that the LoRA adapters do not have quantizers attached.
             for adapter_name in module._lora_adapters:
                 lora_a = module._lora_adapters[adapter_name]["lora_a"]
                 lora_b = module._lora_adapters[adapter_name]["lora_b"]
@@ -784,7 +781,7 @@ def forward_func(mod):
         DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
     ],
 )
-def test_mcore_lora_quantize_save_restore(device_count, lora_config, tmp_path):
+def test_mcore_lora_then_quantize_save_restore(device_count, lora_config, tmp_path):
     spawn_multiprocess_job(
         size=device_count,
         job=partial(_test_mcore_lora_then_quantize_save_restore, lora_config, str(tmp_path)),