 skip_if_no_megatron()


-import modelopt.torch.peft as mtp
+import modelopt.torch.peft as mtpf
 from modelopt.torch.peft.config import kaiming_init, zero_init
 from modelopt.torch.peft.lora.layer import LoRAModule
 from modelopt.torch.utils.plugins import megatron_prefill
@@ -128,7 +128,7 @@ def _test_forward_with_one_lora(lora_config, rank, size):
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

     original_output = megatron_prefill(model, prompt_tokens)
-    mtp.update_model(model, lora_config)
+    mtpf.update_model(model, lora_config)
     lora_output = megatron_prefill(model, prompt_tokens)
     assert lora_output.shape == original_output.shape
     if lora_config == DEFAULT_LORA_CFG_TEST:
@@ -137,10 +137,10 @@ def _test_forward_with_one_lora(lora_config, rank, size):
         )
     else:
         assert not torch.allclose(lora_output, original_output, rtol=1e-5)
-    mtp.disable_adapters(model)
+    mtpf.disable_adapters(model)
     lora_disabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(lora_disabled_output, original_output, rtol=1e-5)
-    mtp.enable_adapters(model)
+    mtpf.enable_adapters(model)
     lora_reenabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(lora_reenabled_output, lora_output, rtol=1e-5)
     lora_module_count = 0
@@ -182,20 +182,20 @@ def _test_forward_with_two_loras(lora_config_1, lora_config_2):
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

     original_output = megatron_prefill(model, prompt_tokens)
-    mtp.update_model(model, lora_config_1)
+    mtpf.update_model(model, lora_config_1)
     lora_1_output = megatron_prefill(model, prompt_tokens)
-    mtp.update_model(model, lora_config_2)
-    mtp.disable_adapters(model, adapters_to_disable=[lora_config_1["adapter_name"]])
-    mtp.enable_adapters(model, adapters_to_enable=[lora_config_2["adapter_name"]])
+    mtpf.update_model(model, lora_config_2)
+    mtpf.disable_adapters(model, adapters_to_disable=[lora_config_1["adapter_name"]])
+    mtpf.enable_adapters(model, adapters_to_enable=[lora_config_2["adapter_name"]])
     lora_2_output = megatron_prefill(model, prompt_tokens)
     if lora_config_1 != DEFAULT_LORA_CFG_TEST or lora_config_2 != DEFAULT_LORA_CFG_TEST:
         assert not torch.allclose(lora_1_output, lora_2_output, rtol=1e-5)
     assert lora_1_output.shape == lora_2_output.shape
-    mtp.enable_adapters(model, adapters_to_enable=[lora_config_1["adapter_name"]])
-    mtp.disable_adapters(model, adapters_to_disable=[lora_config_2["adapter_name"]])
+    mtpf.enable_adapters(model, adapters_to_enable=[lora_config_1["adapter_name"]])
+    mtpf.disable_adapters(model, adapters_to_disable=[lora_config_2["adapter_name"]])
     switched_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(switched_output, lora_1_output, rtol=1e-5)
-    mtp.disable_adapters(model)
+    mtpf.disable_adapters(model)
     both_disabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(both_disabled_output, original_output, rtol=1e-5)

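The diff is a mechanical rename of the `modelopt.torch.peft` import alias from `mtp` to `mtpf`, with every call site in the two tests updated to match; the adapter workflow under test is unchanged. A rough sketch of that workflow under the new alias follows (the helper name `switch_adapters_sketch` and its argument names are illustrative, not part of the library; only `update_model`, `enable_adapters`, `disable_adapters`, and `megatron_prefill` are taken from the calls above, and the LoRA config dicts are assumed to carry an "adapter_name" entry as in the test configs):

import modelopt.torch.peft as mtpf
from modelopt.torch.utils.plugins import megatron_prefill


def switch_adapters_sketch(model, prompt_tokens, lora_cfg_a, lora_cfg_b):
    # Attach two LoRA adapters to the model, one per config dict.
    mtpf.update_model(model, lora_cfg_a)
    mtpf.update_model(model, lora_cfg_b)

    # Activate only the second adapter, selecting it by name.
    mtpf.disable_adapters(model, adapters_to_disable=[lora_cfg_a["adapter_name"]])
    mtpf.enable_adapters(model, adapters_to_enable=[lora_cfg_b["adapter_name"]])
    out_b = megatron_prefill(model, prompt_tokens)

    # With no adapter list given, disable_adapters applies to all adapters,
    # so the output should match the base (pre-LoRA) model, as the tests assert.
    mtpf.disable_adapters(model)
    out_base = megatron_prefill(model, prompt_tokens)
    return out_b, out_base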