+ import copy
from functools import partial

import pytest
import torch
+ import torch.nn.init as init
from _test_utils.import_helper import skip_if_no_megatron
from _test_utils.torch_dist.dist_utils import get_device_counts, spawn_multiprocess_job
from _test_utils.torch_dist.plugins.megatron_common import (

    "*": {
        "rank": 32,
        "scale": 1,
-       "lora_a_init": "kaiming_init",
-       "lora_b_init": "zero_init",
        "enable": True,
    },
    "*output_layer*": {"enable": False},

    "*": {
        "rank": 128,
        "scale": 1,
-       "lora_a_init": "kaiming_init",
-       "lora_b_init": "zero_init",
        "enable": True,
    },
    "*output_layer*": {"enable": False},

    "*": {
        "rank": 32,
        "scale": 1,
-       "lora_a_init": "kaiming_init",
-       "lora_b_init": "kaiming_init",
+       "lora_a_init": init.kaiming_uniform_,
+       "lora_b_init": init.kaiming_uniform_,
        "enable": True,
    },
    "*output_layer*": {"enable": False},

    "*": {
        "rank": 128,
        "scale": 1,
-       "lora_a_init": "kaiming_init",
-       "lora_b_init": "kaiming_init",
+       "lora_a_init": init.kaiming_uniform_,
+       "lora_b_init": init.kaiming_uniform_,
        "enable": True,
    },
    "*output_layer*": {"enable": False},

    "*": {
        "rank": 8,
        "scale": 1,
-       "lora_a_init": "kaiming_init",
-       "lora_b_init": "kaiming_init",
+       "lora_a_init": init.kaiming_uniform_,
+       "lora_b_init": init.kaiming_uniform_,
        "enable": True,
    },
    "*output_layer*": {"enable": False},

    "*self_attention*": {
        "rank": 16,
        "scale": 1,
-       "lora_a_init": "kaiming_init",
-       "lora_b_init": "zero_init",
        "enable": True,
    },
    "*output_layer*": {"enable": False},

@@ -449,14 +445,15 @@ def test_adapter_gradient_flow_freeze_base_model(device_count, lora_config, tmp_

def _test_adapter_gradient_flow_freeze_lora_model(lora_config, tmp_path, rank, size):
    hidden_size = 512
-   lora_config["freeze_lora_weights"] = True
-   lora_config["freeze_base_model"] = False
+   local_cfg = copy.deepcopy(lora_config)
+   local_cfg["freeze_lora_weights"] = True
+   local_cfg["freeze_base_model"] = False

    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
    model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

-   mtpf.update_model(model, lora_config)
+   mtpf.update_model(model, local_cfg)
    model.train()

    # Use a simple forward pass instead for grad check
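
The deepcopy above matters because lora_config is a shared parametrized dict: mutating it inside one spawned test would leak the freeze_* flags into every later case that reuses the same config. A minimal standalone sketch of that pitfall (not taken from the test file; SHARED_CFG and run_case are made-up names):

import copy

SHARED_CFG = {"rank": 8, "freeze_base_model": True}  # stands in for a module-level config

def run_case(cfg):
    local_cfg = copy.deepcopy(cfg)  # isolate this case from the shared dict
    local_cfg["freeze_base_model"] = False
    return local_cfg

assert run_case(SHARED_CFG)["freeze_base_model"] is False
assert SHARED_CFG["freeze_base_model"] is True  # the shared config is untouched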

@@ -569,7 +566,7 @@ def forward_func(mod):
        assert hasattr(module.weight_quantizer, "amax")
        assert getattr(module.input_quantizer, "amax") is not None
        assert getattr(module.weight_quantizer, "amax") is not None
-       # Check if the lora have teh quantizer, they should not have them.
+       # Check that the LoRA adapters do not have quantizers attached.
        for adapter_name in module._lora_adapters:
            lora_a = module._lora_adapters[adapter_name]["lora_a"]
            lora_b = module._lora_adapters[adapter_name]["lora_b"]
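
This hunk (and the three identical ones below) only reword the comment; the assertions that follow in the file are not shown in the diff. A hedged guess at the shape of that check, reusing the quantizer attribute names from the asserts above (the exact assertions are an assumption, not copied from the file):

# Assumed continuation: the adapter branches should carry no quantizers.
assert not hasattr(lora_a, "input_quantizer") and not hasattr(lora_a, "weight_quantizer")
assert not hasattr(lora_b, "input_quantizer") and not hasattr(lora_b, "weight_quantizer")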

@@ -621,7 +618,7 @@ def forward_func(mod):
        assert hasattr(module.weight_quantizer, "amax")
        assert getattr(module.input_quantizer, "amax") is not None
        assert getattr(module.weight_quantizer, "amax") is not None
-       # Check if the lora have teh quantizer, they should not have them.
+       # Check that the LoRA adapters do not have quantizers attached.
        for adapter_name in module._lora_adapters:
            lora_a = module._lora_adapters[adapter_name]["lora_a"]
            lora_b = module._lora_adapters[adapter_name]["lora_b"]

@@ -701,7 +698,7 @@ def forward_func(mod):
        assert hasattr(module.weight_quantizer, "amax")
        assert getattr(module.input_quantizer, "amax") is not None
        assert getattr(module.weight_quantizer, "amax") is not None
-       # Check if the lora have teh quantizer, they should not have them.
+       # Check that the LoRA adapters do not have quantizers attached.
        for adapter_name in module._lora_adapters:
            lora_a = module._lora_adapters[adapter_name]["lora_a"]
            lora_b = module._lora_adapters[adapter_name]["lora_b"]

@@ -765,7 +762,7 @@ def forward_func(mod):
        assert hasattr(module.weight_quantizer, "amax")
        assert getattr(module.input_quantizer, "amax") is not None
        assert getattr(module.weight_quantizer, "amax") is not None
-       # Check if the lora have teh quantizer, they should not have them.
+       # Check that the LoRA adapters do not have quantizers attached.
        for adapter_name in module._lora_adapters:
            lora_a = module._lora_adapters[adapter_name]["lora_a"]
            lora_b = module._lora_adapters[adapter_name]["lora_b"]

@@ -784,7 +781,7 @@ def forward_func(mod):
        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
    ],
)
-def test_mcore_lora_quantize_save_restore(device_count, lora_config, tmp_path):
+def test_mcore_lora_then_quantize_save_restore(device_count, lora_config, tmp_path):
    spawn_multiprocess_job(
        size=device_count,
        job=partial(_test_mcore_lora_then_quantize_save_restore, lora_config, str(tmp_path)),