Commit 1d8ba41

minor on the test case
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 9311e04

File tree

1 file changed (+11 −11 lines)

tests/gpu/torch/peft/test_megatron_peft.py

Lines changed: 11 additions & 11 deletions
@@ -12,7 +12,7 @@
 skip_if_no_megatron()
 
 
-import modelopt.torch.peft as mtp
+import modelopt.torch.peft as mtpf
 from modelopt.torch.peft.config import kaiming_init, zero_init
 from modelopt.torch.peft.lora.layer import LoRAModule
 from modelopt.torch.utils.plugins import megatron_prefill
@@ -128,7 +128,7 @@ def _test_forward_with_one_lora(lora_config, rank, size):
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
 
     original_output = megatron_prefill(model, prompt_tokens)
-    mtp.update_model(model, lora_config)
+    mtpf.update_model(model, lora_config)
     lora_output = megatron_prefill(model, prompt_tokens)
     assert lora_output.shape == original_output.shape
     if lora_config == DEFAULT_LORA_CFG_TEST:
@@ -137,10 +137,10 @@ def _test_forward_with_one_lora(lora_config, rank, size):
         )
     else:
         assert not torch.allclose(lora_output, original_output, rtol=1e-5)
-    mtp.disable_adapters(model)
+    mtpf.disable_adapters(model)
     lora_disabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(lora_disabled_output, original_output, rtol=1e-5)
-    mtp.enable_adapters(model)
+    mtpf.enable_adapters(model)
     lora_reenabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(lora_reenabled_output, lora_output, rtol=1e-5)
     lora_module_count = 0
@@ -182,20 +182,20 @@ def _test_forward_with_two_loras(lora_config_1, lora_config_2):
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
 
     original_output = megatron_prefill(model, prompt_tokens)
-    mtp.update_model(model, lora_config_1)
+    mtpf.update_model(model, lora_config_1)
     lora_1_output = megatron_prefill(model, prompt_tokens)
-    mtp.update_model(model, lora_config_2)
-    mtp.disable_adapters(model, adapters_to_disable=[lora_config_1["adapter_name"]])
-    mtp.enable_adapters(model, adapters_to_enable=[lora_config_2["adapter_name"]])
+    mtpf.update_model(model, lora_config_2)
+    mtpf.disable_adapters(model, adapters_to_disable=[lora_config_1["adapter_name"]])
+    mtpf.enable_adapters(model, adapters_to_enable=[lora_config_2["adapter_name"]])
     lora_2_output = megatron_prefill(model, prompt_tokens)
     if lora_config_1 != DEFAULT_LORA_CFG_TEST or lora_config_2 != DEFAULT_LORA_CFG_TEST:
         assert not torch.allclose(lora_1_output, lora_2_output, rtol=1e-5)
     assert lora_1_output.shape == lora_2_output.shape
-    mtp.enable_adapters(model, adapters_to_enable=[lora_config_1["adapter_name"]])
-    mtp.disable_adapters(model, adapters_to_disable=[lora_config_2["adapter_name"]])
+    mtpf.enable_adapters(model, adapters_to_enable=[lora_config_1["adapter_name"]])
+    mtpf.disable_adapters(model, adapters_to_disable=[lora_config_2["adapter_name"]])
     switched_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(switched_output, lora_1_output, rtol=1e-5)
-    mtp.disable_adapters(model)
+    mtpf.disable_adapters(model)
     both_disabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(both_disabled_output, original_output, rtol=1e-5)
 
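For context, the commit only renames the import alias from mtp to mtpf; the test logic is unchanged. Below is a minimal, self-contained sketch (not part of the commit) of the adapter-toggling pattern these tests exercise, using only the modelopt.torch.peft calls that appear in the diff; model, lora_cfg, and prompt_tokens are hypothetical stand-ins for the objects the test helpers construct.

import torch

import modelopt.torch.peft as mtpf
from modelopt.torch.utils.plugins import megatron_prefill


def check_adapter_toggling(model, lora_cfg, prompt_tokens):
    # Baseline output of the unmodified model.
    baseline = megatron_prefill(model, prompt_tokens)
    # Attach the LoRA adapter described by lora_cfg.
    mtpf.update_model(model, lora_cfg)
    with_lora = megatron_prefill(model, prompt_tokens)
    # Disabling the adapter should restore the baseline behavior.
    mtpf.disable_adapters(model)
    assert torch.allclose(megatron_prefill(model, prompt_tokens), baseline, rtol=1e-5)
    # Re-enabling it should reproduce the LoRA output.
    mtpf.enable_adapters(model)
    assert torch.allclose(megatron_prefill(model, prompt_tokens), with_lora, rtol=1e-5)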
