@@ -15,5 +15,4 @@
 import modelopt.torch.peft as mtpf
-from modelopt.torch.peft.config import kaiming_init, zero_init
 from modelopt.torch.peft.lora.layer import LoRAModule
 from modelopt.torch.utils.plugins import megatron_prefill

@@ -24,8 +23,8 @@
         "*": {
             "rank": 32,
             "scale": 1,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": zero_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "zero_init",
             "enable": True,
         },
     },
@@ -38,23 +37,22 @@
         "*": {
             "rank": 32,
             "scale": 1,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": kaiming_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "kaiming_init",
             "enable": True,
         },
     },
 }

-# Additional configurations for comprehensive testing
-SMALL_RANK_LORA_CFG = {
+DEFAULT_LORA_CFG_RANDOM_INIT_SMALL_RANK_TEST = {
     "adapter_type": "lora",
-    "adapter_name": "small_rank",
+    "adapter_name": "small",
     "adapter_cfg": {
         "*": {
-            "rank": 4,
+            "rank": 8,
             "scale": 1,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": zero_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "kaiming_init",
             "enable": True,
         },
     },
@@ -67,8 +65,8 @@
         "*": {
             "rank": 16,
             "scale": 10.0,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": zero_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "zero_init",
             "enable": True,
         },
     },
@@ -78,12 +76,12 @@
     "adapter_type": "lora",
     "adapter_name": "selective",
     "adapter_cfg": {
-        "*": {"enable": False},  # Disable by default
-        "*self_attention*": {  # Enable only for self-attention layers
+        "*": {"enable": False},
+        "*self_attention*": {
             "rank": 16,
             "scale": 1,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": zero_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "zero_init",
             "enable": True,
         },
     },
@@ -131,41 +129,53 @@ def _test_forward_with_one_lora(lora_config, rank, size):
     mtpf.update_model(model, lora_config)
     lora_output = megatron_prefill(model, prompt_tokens)
     assert lora_output.shape == original_output.shape
-    if lora_config == DEFAULT_LORA_CFG_TEST:
+    if lora_config == DEFAULT_LORA_CFG_RANDOM_INIT_TEST:
+        assert not torch.allclose(lora_output, original_output, rtol=1e-5)
+    else:
         assert torch.allclose(lora_output, original_output, rtol=1e-5), (
             f"{lora_output}, {original_output}"
         )
-    else:
-        assert not torch.allclose(lora_output, original_output, rtol=1e-5)
     mtpf.disable_adapters(model)
     lora_disabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(lora_disabled_output, original_output, rtol=1e-5)
     mtpf.enable_adapters(model)
     lora_reenabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(lora_reenabled_output, lora_output, rtol=1e-5)
     lora_module_count = 0
+    lora_with_adapter_count = 0
     for name, module in model.named_modules():
         if isinstance(module, LoRAModule):
             lora_module_count += 1
-            assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
-            assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")

             if lora_config == SELECTIVE_LAYER_LORA_CFG:
-                if "self_attention" not in name:
-                    # These modules should have LoRA disabled
-                    assert not module._lora_adapters[lora_config["adapter_name"]]["enable"]
+                if "self_attention" in name:
+                    # Only self_attention modules should have the adapter
+                    assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
+                    assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+                    assert lora_config["adapter_name"] in module._lora_adapters
+                    assert module._lora_adapters[lora_config["adapter_name"]]["enable"]
+                    lora_with_adapter_count += 1
+                else:
+                    # Other modules should NOT have the adapter at all
+                    assert not hasattr(module, f"lora_a_{lora_config['adapter_name']}")
+                    assert not hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+                    assert lora_config["adapter_name"] not in module._lora_adapters
+            else:
+                # For non-selective configs, all LoRA modules should have the adapter
+                assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
+                assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+                lora_with_adapter_count += 1

     assert lora_module_count > 0
+    assert lora_with_adapter_count > 0


 @pytest.mark.parametrize(
     "lora_config",
     [
         DEFAULT_LORA_CFG_TEST,
-        # DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
-        # SMALL_RANK_LORA_CFG,
-        # LARGE_SCALE_LORA_CFG,
-        # SELECTIVE_LAYER_LORA_CFG,
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+        SELECTIVE_LAYER_LORA_CFG,
     ],
 )
 def test_forward_with_one_lora(lora_config):
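Note: the branch above (allclose for the zero-init config, not-allclose for DEFAULT_LORA_CFG_RANDOM_INIT_TEST) follows from the standard LoRA update y = Wx + scale * B(Ax): a zero-initialized lora_b makes the adapter a no-op, while random init on both factors perturbs the output. The snippet below is a standalone sketch of that arithmetic with plain tensors; it is not the modelopt forward path.

# Standalone illustration of why the asserts branch on the init scheme.
import torch

torch.manual_seed(0)
x = torch.randn(4, 32)           # batch of hidden states
W = torch.randn(32, 32)          # frozen base weight
A = torch.randn(8, 32)           # lora_a, rank 8
B_zero = torch.zeros(32, 8)      # zero_init: LoRA delta is exactly zero
B_rand = torch.randn(32, 8)      # random init: LoRA delta is non-zero

base = x @ W.T
assert torch.allclose(base + x @ A.T @ B_zero.T, base)       # matches base output
assert not torch.allclose(base + x @ A.T @ B_rand.T, base)   # diverges from base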
@@ -174,7 +184,7 @@ def test_forward_with_one_lora(lora_config):
     )


-def _test_forward_with_two_loras(lora_config_1, lora_config_2):
+def _test_forward_with_two_loras(lora_config_1, lora_config_2, rank, size):
     """Test forward pass with two LoRA adapters and adapter switching."""
     hidden_size = 320
     initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
@@ -183,21 +193,31 @@ def _test_forward_with_two_loras(lora_config_1, lora_config_2):

     original_output = megatron_prefill(model, prompt_tokens)
     mtpf.update_model(model, lora_config_1)
+    # output from the first lora only
     lora_1_output = megatron_prefill(model, prompt_tokens)
+
     mtpf.update_model(model, lora_config_2)
+
     mtpf.disable_adapters(model, adapters_to_disable=[lora_config_1["adapter_name"]])
     mtpf.enable_adapters(model, adapters_to_enable=[lora_config_2["adapter_name"]])
+
+    # output from the 2nd lora only
     lora_2_output = megatron_prefill(model, prompt_tokens)
-    if lora_config_1 != DEFAULT_LORA_CFG_TEST or lora_config_2 != DEFAULT_LORA_CFG_TEST:
-        assert not torch.allclose(lora_1_output, lora_2_output, rtol=1e-5)
+
     assert lora_1_output.shape == lora_2_output.shape
+    # Should not be the same
+    assert not torch.allclose(lora_1_output, lora_2_output)
+
     mtpf.enable_adapters(model, adapters_to_enable=[lora_config_1["adapter_name"]])
-    mtpf.disable_adapters(model, adapters_to_disable=[lora_config_2["adapter_name"]])
-    switched_output = megatron_prefill(model, prompt_tokens)
-    assert torch.allclose(switched_output, lora_1_output, rtol=1e-5)
+    mtpf.enable_adapters(model, adapters_to_enable=[lora_config_2["adapter_name"]])
+    lora_all_output = megatron_prefill(model, prompt_tokens)
+
+    assert not torch.allclose(lora_all_output, lora_1_output)
+    assert not torch.allclose(lora_all_output, lora_2_output)
+
     mtpf.disable_adapters(model)
     both_disabled_output = megatron_prefill(model, prompt_tokens)
-    assert torch.allclose(both_disabled_output, original_output, rtol=1e-5)
+    assert torch.allclose(both_disabled_output, original_output)

     for _, module in model.named_modules():
         if isinstance(module, LoRAModule):
@@ -208,18 +228,18 @@ def _test_forward_with_two_loras(lora_config_1, lora_config_2):
             assert len(module._lora_adapters) == 2


-# @pytest.mark.parametrize(
-#     "lora_config_1, lora_config_2",
-#     [
-#         (DEFAULT_LORA_CFG_TEST, DEFAULT_LORA_CFG_RANDOM_INIT_TEST),
-#         (SMALL_RANK_LORA_CFG, LARGE_SCALE_LORA_CFG),
-#         (DEFAULT_LORA_CFG_TEST, SELECTIVE_LAYER_LORA_CFG),
-#     ],
-# )
-# def test_forward_with_two_loras(lora_config_1, lora_config_2):
-#     spawn_multiprocess_job(
-#         size=1, job=partial(_test_forward_with_two_loras, lora_config_1, lora_config_2), backend="nccl"
-#     )
+@pytest.mark.parametrize(
+    ("lora_config_1", "lora_config_2"),
+    [
+        (DEFAULT_LORA_CFG_RANDOM_INIT_TEST, DEFAULT_LORA_CFG_RANDOM_INIT_SMALL_RANK_TEST),
+    ],
+)
+def test_forward_with_two_loras(lora_config_1, lora_config_2):
+    spawn_multiprocess_job(
+        size=1,
+        job=partial(_test_forward_with_two_loras, lora_config_1, lora_config_2),
+        backend="nccl",
+    )


 # def test_edge_cases_and_error_handling():