
Commit a2cea26

Update test cases / update namings
Signed-off-by: Jingyu Xin <[email protected]>
1 parent d6b2e60 · commit a2cea26

File tree: 2 files changed (+166, −4 lines)

modelopt/torch/peft/config.py

Lines changed: 4 additions & 4 deletions
@@ -27,12 +27,12 @@
 __all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"]


-def default_lora_a_init(weight):
+def kaiming_init(weight):
     """Default initialization for LoRA A matrix using Kaiming uniform."""
     return init.kaiming_uniform_(weight, a=math.sqrt(5))


-def default_lora_b_init(weight):
+def zero_init(weight):
     """Default initialization for LoRA B matrix using zeros."""
     return init.zeros_(weight)

@@ -62,13 +62,13 @@ class PEFTAttributeConfig(ModeloptBaseConfig):
     )

     lora_a_init: Callable[[object], None] | None = ModeloptField(
-        default=default_lora_a_init,
+        default=kaiming_init,
         title="LoRA A matrix initializer",
         description="Custom initialization function for LoRA A matrix. Default to Kaiming uniform initialization.",
     )

     lora_b_init: Callable[[object], None] | None = ModeloptField(
-        default=default_lora_b_init,
+        default=zero_init,
         title="LoRA B matrix initializer",
         description="Custom initialization function for LoRA B matrix. Default to zero initialization.",
     )
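
For readers tracking the rename: default_lora_a_init and default_lora_b_init are now exported as kaiming_init and zero_init, so adapter configs that referenced the old names must import the new ones. A minimal sketch of such a config follows (it mirrors DEFAULT_LORA_CFG_TEST in the test file added below; the adapter name and the model object are illustrative, not part of this commit):

import modelopt.torch.peft as mtp
from modelopt.torch.peft.config import kaiming_init, zero_init

# Illustrative adapter config using the renamed initializers
# (same structure as DEFAULT_LORA_CFG_TEST in the new test file below).
example_lora_cfg = {
    "adapter_type": "lora",
    "adapter_name": "example",  # hypothetical adapter name
    "adapter_cfg": {
        "*": {  # match every LoRA-capable module
            "rank": 32,
            "scale": 1,
            "lora_a_init": kaiming_init,  # formerly default_lora_a_init
            "lora_b_init": zero_init,     # formerly default_lora_b_init
            "enable": True,
        },
    },
}

# mtp.update_model(model, example_lora_cfg)  # 'model' is assumed to exist already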
Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
+import pytest
+import torch
+from _test_utils.import_helper import skip_if_no_megatron
+from _test_utils.torch_dist.plugins.megatron_common import (
+    get_mcore_gpt_model,
+    initialize_for_megatron,
+)
+
+skip_if_no_megatron()
+
+
+import modelopt.torch.peft as mtp
+from modelopt.torch.peft.config import kaiming_init, zero_init
+from modelopt.torch.peft.lora.layer import LoRAModule
+from modelopt.torch.utils.plugins import megatron_prefill
+
+DEFAULT_LORA_CFG_TEST = {
+    "adapter_type": "lora",
+    "adapter_name": "default",
+    "adapter_cfg": {
+        "*": {
+            "rank": 32,
+            "scale": 1,
+            "lora_a_init": kaiming_init,
+            "lora_b_init": zero_init,
+            "enable": True,
+        },
+    },
+}
+
+DEFAULT_LORA_CFG_RANDOM_INIT_TEST = {
+    "adapter_type": "lora",
+    "adapter_name": "random",
+    "adapter_cfg": {
+        "*": {
+            "rank": 32,
+            "scale": 1,
+            "lora_a_init": kaiming_init,
+            "lora_b_init": kaiming_init,
+            "enable": True,
+        },
+    },
+}
+
+
+def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
+    """Build the model."""
+
+    if meta_device:
+        with torch.device("meta"):
+            gpt_model = get_mcore_gpt_model(
+                tensor_model_parallel_size=tp_size,
+                num_layers=4,
+                ffn_hidden_size=None,
+                num_attention_heads=4,
+                activation_func="squared_relu",
+                transformer_impl="local",
+                hidden_size=hidden_size,
+                vocab_size=vocab_size,
+                use_cpu_initialization=meta_device,
+            )
+    else:
+        gpt_model = get_mcore_gpt_model(
+            tensor_model_parallel_size=tp_size,
+            num_layers=4,
+            ffn_hidden_size=None,
+            num_attention_heads=4,
+            activation_func="squared_relu",
+            transformer_impl="local",
+            hidden_size=hidden_size,
+            vocab_size=vocab_size,
+        ).cuda()
+    return gpt_model.eval()
+
+
+@pytest.mark.parametrize(
+    "lora_config",
+    [
+        DEFAULT_LORA_CFG_TEST,
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_forward_with_one_lora(lora_config):
+    hidden_size = 320
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=1, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+    original_output = megatron_prefill(model, prompt_tokens)
+    mtp.update_model(model, lora_config)
+    lora_output = megatron_prefill(model, prompt_tokens)
+    assert lora_output.shape == original_output.shape
+    if lora_config == DEFAULT_LORA_CFG_TEST:
+        assert torch.allclose(lora_output, original_output)
+    else:
+        assert not torch.allclose(lora_output, original_output)
+
+    mtp.disable_adapters(model)
+    lora_disabled_output = megatron_prefill(model, prompt_tokens)
+    assert torch.allclose(lora_disabled_output, original_output)
+
+    for _, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
+            assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+
+
+@pytest.mark.parametrize(
+    "lora_config_1",
+    [
+        DEFAULT_LORA_CFG_TEST,
+    ],
+)
+@pytest.mark.parametrize(
+    "lora_config_2",
+    [
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+    ],
+)
+def test_forward_with_two_loras(lora_config_1, lora_config_2):
+    hidden_size = 320
+    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
+    model = _gpt_model_provider(tp_size=1, hidden_size=hidden_size)
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+    mtp.update_model(model, lora_config_1)
+    lora_1_output = megatron_prefill(model, prompt_tokens)
+    mtp.update_model(model, lora_config_2)
+    lora_2_output = megatron_prefill(model, prompt_tokens)
+
+    assert not torch.allclose(lora_1_output, lora_2_output)
+    assert lora_1_output.shape == lora_2_output.shape
+
+    for _, module in model.named_modules():
+        if isinstance(module, LoRAModule):
+            assert hasattr(module, f"lora_a_{lora_config_1['adapter_name']}")
+            assert hasattr(module, f"lora_b_{lora_config_1['adapter_name']}")
+
+            assert hasattr(module, f"lora_a_{lora_config_2['adapter_name']}")
+            assert hasattr(module, f"lora_b_{lora_config_2['adapter_name']}")
+
+
+def test_forward_with_lora_quantize():
+    pass
+
+
+def test_forward_with_quantize_lora():
+    pass
+
+
+def test_one_lora_save_restore():
+    pass
+
+
+def test_two_loras_save_restore():
+    pass
+
+
+def test_one_lora_quantize_save_restore():
+    pass
+
+
+def test_two_loras_quantize_save_restore():
+    pass
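
The quantize and save/restore cases above are still placeholders in this commit. Purely as a sketch of the shape such a test could take (assuming modelopt.torch.opt's save/restore are the intended serialization entry points, which this commit does not confirm), test_one_lora_save_restore might look roughly like this:

import modelopt.torch.opt as mto  # assumption: modelopt's generic save/restore entry points

def _sketch_one_lora_save_restore(tmp_path):
    # Hypothetical sketch only; the committed test body is just `pass`.
    initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
    model = _gpt_model_provider(tp_size=1, hidden_size=320)
    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

    # Use the random-init config so the adapter actually changes the output.
    mtp.update_model(model, DEFAULT_LORA_CFG_RANDOM_INIT_TEST)
    lora_output = megatron_prefill(model, prompt_tokens)
    mto.save(model, tmp_path / "lora_model.pth")

    # Restore the adapter state into a freshly built model and compare outputs.
    restored = mto.restore(_gpt_model_provider(tp_size=1, hidden_size=320), tmp_path / "lora_model.pth")
    restored_output = megatron_prefill(restored, prompt_tokens)
    assert torch.allclose(lora_output, restored_output)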
