Commit 8df12bc

Update the test case and some minor updates
Signed-off-by: Jingyu Xin <[email protected]>
1 parent f98711e commit 8df12bc

File tree (3 files changed: +83, -64 lines changed)

modelopt/torch/peft/config.py
modelopt/torch/peft/conversion.py
tests/gpu/torch/peft/test_megatron_peft.py


modelopt/torch/peft/config.py

Lines changed: 0 additions & 12 deletions
@@ -15,27 +15,15 @@
 
 """Configuration classes for PEFT methods."""
 
-import math
 from collections.abc import Callable
 
-import torch.nn.init as init
 from pydantic import field_validator
 
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
 
 __all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"]
 
 
-def kaiming_init(weight):
-    """Default initialization for LoRA A matrix using Kaiming uniform."""
-    return init.kaiming_uniform_(weight, a=math.sqrt(5))
-
-
-def zero_init(weight):
-    """Default initialization for LoRA B matrix using zeros."""
-    return init.zeros_(weight)
-
-
 class PEFTAttributeConfig(ModeloptBaseConfig):
     """Configuration for PEFT adapter attributes."""
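Note on the removed helpers: kaiming_init and zero_init no longer live in config.py, and the test configurations below now refer to initializers by name ("kaiming_init", "zero_init") instead of importing the functions. A minimal sketch of how such string names could be resolved back to equivalent callables, assuming a simple name-to-function registry (hypothetical helper names; the actual resolution logic inside modelopt may differ):

    import math

    import torch.nn.init as init

    # Hypothetical registry mapping initializer name strings to callables (illustrative only).
    _LORA_INIT_FNS = {
        # Matches the behavior of the removed kaiming_init helper.
        "kaiming_init": lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5)),
        # Matches the behavior of the removed zero_init helper.
        "zero_init": init.zeros_,
    }


    def resolve_lora_init(name_or_fn):
        """Accept either a callable or a registered initializer name string."""
        if callable(name_or_fn):
            return name_or_fn
        if name_or_fn in _LORA_INIT_FNS:
            return _LORA_INIT_FNS[name_or_fn]
        raise ValueError(f"Unknown LoRA initializer: {name_or_fn!r}")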

modelopt/torch/peft/conversion.py

Lines changed: 15 additions & 4 deletions
@@ -106,12 +106,23 @@ def add_adapter(model, config: PEFTConfig):
 
     for name, module in model.named_modules():
         if isinstance(module, LoRAModule):
+            # Collect all matching adapter settings and merge them
+            # Later patterns override earlier ones
+            merged_setting = None
             for wildcard_or_filter_func, adapter_setting in adapter_cfg.items():
                 if _matches(name, wildcard_or_filter_func):
-                    module.update_layer_lora(
-                        adapter_name,
-                        adapter_setting,
-                    )
+                    if merged_setting is None:
+                        merged_setting = adapter_setting.copy()
+                    else:
+                        merged_setting.update(adapter_setting)
+
+            # Only call update_layer_lora if we have settings and enable is not False
+            # If enable=False, skip adding the adapter entirely
+            if merged_setting is not None and merged_setting.get("enable", True):
+                module.update_layer_lora(
+                    adapter_name,
+                    merged_setting,
+                )
 
     return model
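With this change, every pattern in adapter_cfg that matches a module contributes to a single merged setting, later patterns overriding earlier ones, and a module whose merged setting resolves to enable=False never receives the adapter. A standalone sketch of that merge behavior, using fnmatch as a stand-in for the module's _matches helper and a config shaped like SELECTIVE_LAYER_LORA_CFG from the tests (module names are illustrative):

    from fnmatch import fnmatch

    # Example adapter_cfg: disable everywhere, then re-enable on self-attention layers.
    adapter_cfg = {
        "*": {"enable": False},
        "*self_attention*": {"rank": 16, "scale": 1, "enable": True},
    }


    def merged_setting_for(module_name, adapter_cfg):
        """Merge all matching patterns; later patterns override earlier ones."""
        merged = None
        for pattern, setting in adapter_cfg.items():
            if fnmatch(module_name, pattern):
                merged = dict(setting) if merged is None else {**merged, **setting}
        return merged


    for name in ("decoder.layers.0.self_attention.linear_qkv", "decoder.layers.0.mlp.linear_fc1"):
        setting = merged_setting_for(name, adapter_cfg)
        should_add = setting is not None and setting.get("enable", True)
        print(name, "->", setting, "| add adapter:", should_add)
    # The self_attention module resolves to the merged rank-16 setting and gets the adapter;
    # the MLP module resolves to {"enable": False} and is skipped entirely.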

tests/gpu/torch/peft/test_megatron_peft.py

Lines changed: 68 additions & 48 deletions
@@ -13,7 +13,6 @@
 
 
 import modelopt.torch.peft as mtpf
-from modelopt.torch.peft.config import kaiming_init, zero_init
 from modelopt.torch.peft.lora.layer import LoRAModule
 from modelopt.torch.utils.plugins import megatron_prefill
 
@@ -24,8 +23,8 @@
         "*": {
             "rank": 32,
             "scale": 1,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": zero_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "zero_init",
             "enable": True,
         },
     },
@@ -38,23 +37,22 @@
         "*": {
             "rank": 32,
             "scale": 1,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": kaiming_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "kaiming_init",
             "enable": True,
         },
     },
 }
 
-# Additional configurations for comprehensive testing
-SMALL_RANK_LORA_CFG = {
+DEFAULT_LORA_CFG_RANDOM_INIT_SMALL_RANK_TEST = {
     "adapter_type": "lora",
-    "adapter_name": "small_rank",
+    "adapter_name": "small",
     "adapter_cfg": {
         "*": {
-            "rank": 4,
+            "rank": 8,
             "scale": 1,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": zero_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "kaiming_init",
             "enable": True,
         },
     },
@@ -67,8 +65,8 @@
         "*": {
             "rank": 16,
             "scale": 10.0,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": zero_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "zero_init",
             "enable": True,
         },
     },
@@ -78,12 +76,12 @@
     "adapter_type": "lora",
     "adapter_name": "selective",
     "adapter_cfg": {
-        "*": {"enable": False},  # Disable by default
-        "*self_attention*": {  # Enable only for self-attention layers
+        "*": {"enable": False},
+        "*self_attention*": {
             "rank": 16,
             "scale": 1,
-            "lora_a_init": kaiming_init,
-            "lora_b_init": zero_init,
+            "lora_a_init": "kaiming_init",
+            "lora_b_init": "zero_init",
             "enable": True,
         },
     },
@@ -131,41 +129,53 @@ def _test_forward_with_one_lora(lora_config, rank, size):
     mtpf.update_model(model, lora_config)
     lora_output = megatron_prefill(model, prompt_tokens)
     assert lora_output.shape == original_output.shape
-    if lora_config == DEFAULT_LORA_CFG_TEST:
+    if lora_config == DEFAULT_LORA_CFG_RANDOM_INIT_TEST:
+        assert not torch.allclose(lora_output, original_output, rtol=1e-5)
+    else:
         assert torch.allclose(lora_output, original_output, rtol=1e-5), (
             f"{lora_output}, {original_output}"
         )
-    else:
-        assert not torch.allclose(lora_output, original_output, rtol=1e-5)
     mtpf.disable_adapters(model)
     lora_disabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(lora_disabled_output, original_output, rtol=1e-5)
     mtpf.enable_adapters(model)
     lora_reenabled_output = megatron_prefill(model, prompt_tokens)
     assert torch.allclose(lora_reenabled_output, lora_output, rtol=1e-5)
     lora_module_count = 0
+    lora_with_adapter_count = 0
     for name, module in model.named_modules():
         if isinstance(module, LoRAModule):
             lora_module_count += 1
-            assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
-            assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
 
             if lora_config == SELECTIVE_LAYER_LORA_CFG:
-                if "self_attention" not in name:
-                    # These modules should have LoRA disabled
-                    assert not module._lora_adapters[lora_config["adapter_name"]]["enable"]
+                if "self_attention" in name:
+                    # Only self_attention modules should have the adapter
+                    assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
+                    assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+                    assert lora_config["adapter_name"] in module._lora_adapters
+                    assert module._lora_adapters[lora_config["adapter_name"]]["enable"]
+                    lora_with_adapter_count += 1
+                else:
+                    # Other modules should NOT have the adapter at all
+                    assert not hasattr(module, f"lora_a_{lora_config['adapter_name']}")
+                    assert not hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+                    assert lora_config["adapter_name"] not in module._lora_adapters
+            else:
+                # For non-selective configs, all LoRA modules should have the adapter
+                assert hasattr(module, f"lora_a_{lora_config['adapter_name']}")
+                assert hasattr(module, f"lora_b_{lora_config['adapter_name']}")
+                lora_with_adapter_count += 1
 
     assert lora_module_count > 0
+    assert lora_with_adapter_count > 0
 
 
 @pytest.mark.parametrize(
     "lora_config",
     [
         DEFAULT_LORA_CFG_TEST,
-        # DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
-        # SMALL_RANK_LORA_CFG,
-        # LARGE_SCALE_LORA_CFG,
-        # SELECTIVE_LAYER_LORA_CFG,
+        DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
+        SELECTIVE_LAYER_LORA_CFG,
     ],
 )
 def test_forward_with_one_lora(lora_config):
@@ -174,7 +184,7 @@ def test_forward_with_one_lora(lora_config):
     )
 
 
-def _test_forward_with_two_loras(lora_config_1, lora_config_2):
+def _test_forward_with_two_loras(lora_config_1, lora_config_2, rank, size):
     """Test forward pass with two LoRA adapters and adapter switching."""
     hidden_size = 320
     initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
@@ -183,21 +193,31 @@ def _test_forward_with_two_loras(lora_config_1, lora_config_2):
 
     original_output = megatron_prefill(model, prompt_tokens)
     mtpf.update_model(model, lora_config_1)
+    # output from the first lora only
     lora_1_output = megatron_prefill(model, prompt_tokens)
+
     mtpf.update_model(model, lora_config_2)
+
     mtpf.disable_adapters(model, adapters_to_disable=[lora_config_1["adapter_name"]])
     mtpf.enable_adapters(model, adapters_to_enable=[lora_config_2["adapter_name"]])
+
+    # output from the 2nd lora only
     lora_2_output = megatron_prefill(model, prompt_tokens)
-    if lora_config_1 != DEFAULT_LORA_CFG_TEST or lora_config_2 != DEFAULT_LORA_CFG_TEST:
-        assert not torch.allclose(lora_1_output, lora_2_output, rtol=1e-5)
+
     assert lora_1_output.shape == lora_2_output.shape
+    # Should not be the same
+    assert not torch.allclose(lora_1_output, lora_2_output)
+
     mtpf.enable_adapters(model, adapters_to_enable=[lora_config_1["adapter_name"]])
-    mtpf.disable_adapters(model, adapters_to_disable=[lora_config_2["adapter_name"]])
-    switched_output = megatron_prefill(model, prompt_tokens)
-    assert torch.allclose(switched_output, lora_1_output, rtol=1e-5)
+    mtpf.enable_adapters(model, adapters_to_enable=[lora_config_2["adapter_name"]])
+    lora_all_output = megatron_prefill(model, prompt_tokens)
+
+    assert not torch.allclose(lora_all_output, lora_1_output)
+    assert not torch.allclose(lora_all_output, lora_2_output)
+
     mtpf.disable_adapters(model)
     both_disabled_output = megatron_prefill(model, prompt_tokens)
-    assert torch.allclose(both_disabled_output, original_output, rtol=1e-5)
+    assert torch.allclose(both_disabled_output, original_output)
 
     for _, module in model.named_modules():
         if isinstance(module, LoRAModule):
@@ -208,18 +228,18 @@ def _test_forward_with_two_loras(lora_config_1, lora_config_2):
             assert len(module._lora_adapters) == 2
 
 
-# @pytest.mark.parametrize(
-#     "lora_config_1,lora_config_2",
-#     [
-#         (DEFAULT_LORA_CFG_TEST, DEFAULT_LORA_CFG_RANDOM_INIT_TEST),
-#         (SMALL_RANK_LORA_CFG, LARGE_SCALE_LORA_CFG),
-#         (DEFAULT_LORA_CFG_TEST, SELECTIVE_LAYER_LORA_CFG),
-#     ],
-# )
-# def test_forward_with_two_loras(lora_config_1, lora_config_2):
-#     spawn_multiprocess_job(
-#         size=1, job=partial(_test_forward_with_two_loras, lora_config_1, lora_config_2), backend="nccl"
-#     )
+@pytest.mark.parametrize(
+    ("lora_config_1", "lora_config_2"),
+    [
+        (DEFAULT_LORA_CFG_RANDOM_INIT_TEST, DEFAULT_LORA_CFG_RANDOM_INIT_SMALL_RANK_TEST),
+    ],
+)
+def test_forward_with_two_loras(lora_config_1, lora_config_2):
+    spawn_multiprocess_job(
+        size=1,
+        job=partial(_test_forward_with_two_loras, lora_config_1, lora_config_2),
+        backend="nccl",
+    )
 
 
 # def test_edge_cases_and_error_handling():
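Taken together, the commit means user-facing LoRA configs now name their initializers as strings instead of importing helpers from modelopt.torch.peft.config. A minimal usage sketch based on the test fixtures above (the Megatron GPT model construction is omitted and assumed to match the tests):

    import modelopt.torch.peft as mtpf

    # String initializer names replace the removed kaiming_init/zero_init imports.
    lora_cfg = {
        "adapter_type": "lora",
        "adapter_name": "example",
        "adapter_cfg": {
            "*": {"enable": False},
            "*self_attention*": {
                "rank": 16,
                "scale": 1,
                "lora_a_init": "kaiming_init",
                "lora_b_init": "zero_init",
                "enable": True,
            },
        },
    }

    # `model` is assumed to be a Megatron GPT model built as in the tests above.
    mtpf.update_model(model, lora_cfg)
    mtpf.disable_adapters(model, adapters_to_disable=["example"])
    mtpf.enable_adapters(model, adapters_to_enable=["example"])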
