Commit 0b310fb

update init functions
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 5030b43 commit 0b310fb

File tree: 5 files changed, 55 additions & 60 deletions


modelopt/torch/peft/config.py

Lines changed: 31 additions & 9 deletions
@@ -15,14 +15,19 @@
 
 """Configuration classes for PEFT methods."""
 
+import inspect
 from collections.abc import Callable
 
+import torch.nn.init
 from pydantic import field_validator
+from torch import Tensor
 
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
 
 __all__ = ["ExportPEFTConfig", "PEFTAttributeConfig", "PEFTConfig"]
 
+InitFn = Callable[..., Tensor]
+
 
 class PEFTAttributeConfig(ModeloptBaseConfig):
     """Configuration for PEFT adapter attributes."""
@@ -48,26 +53,43 @@ class PEFTAttributeConfig(ModeloptBaseConfig):
         description="Scaling factor for the LoRA output. Controls the magnitude of the adaptation.",
     )
 
-    lora_a_init: str = ModeloptField(
-        default="kaiming_init",
+    lora_a_init: InitFn = ModeloptField(
+        default=torch.nn.init.kaiming_uniform_,
         title="LoRA A matrix initializer",
-        description="Custom initialization function for LoRA A matrix. Default to Kaiming uniform initialization.",
+        description="Custom initialization function for LoRA A matrix. \
+            Default to Kaiming uniform initialization. For more init methods \
+            you can refer to https://docs.pytorch.org/docs/stable/nn.init.html",
     )
 
-    lora_b_init: str = ModeloptField(
-        default="zero_init",
+    lora_b_init: InitFn = ModeloptField(
+        default=torch.nn.init.zeros_,
         title="LoRA B matrix initializer",
-        description="Custom initialization function for LoRA B matrix. Default to zero initialization.",
+        description="Custom initialization function for LoRA B matrix. Default to zero initialization. \
+            For more init methods you can refer to https://docs.pytorch.org/docs/stable/nn.init.html",
    )
 
     @field_validator("lora_a_init", "lora_b_init")
     @classmethod
     def validate_init_method(cls, v):
         """Validate initialization method is supported."""
-        valid_methods = {"kaiming_init", "zero_init"}
-        if v not in valid_methods:
+        if callable(v):
+            # Check if this is a function from torch.nn.init
+            module = inspect.getmodule(v)
+            if module is not torch.nn.init:
+                raise ValueError(
+                    f"Callable initialization method must be from torch.nn.init module, "
+                    f"got function from {module.__name__ if module else 'unknown module'}"
+                )
+            func_name = getattr(v, "__name__", "")
+            if not func_name.endswith("_"):
+                raise ValueError(
+                    f"Initialization method must be an in-place function (name should end with '_'), "
+                    f"got '{func_name}'. For example,"
+                    f" use torch.nn.init.kaiming_uniform_ instead of torch.nn.init.kaiming_uniform"
+                )
+        else:
             raise ValueError(
-                f"Invalid initialization method: {v}. Supported methods: {', '.join(valid_methods)}"
+                f"Initialization method must be a callable function from torch.nn.init, got {type(v)}"
             )
         return v
 
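
The diff above replaces the old string keys ("kaiming_init" / "zero_init") with callables typed as `InitFn`, and the validator now only accepts in-place initializers defined in `torch.nn.init`. Below is a minimal standalone sketch of those acceptance rules, written outside of pydantic and modelopt; the helper name `check_init_fn` is invented for illustration and is not part of the commit.

```python
import inspect

import torch.nn.init


def check_init_fn(fn) -> None:
    """Mirror the acceptance rules of PEFTAttributeConfig.validate_init_method:
    the initializer must be callable, defined in torch.nn.init, and be the
    in-place variant (name ending with '_')."""
    if not callable(fn):
        raise ValueError(f"Expected a callable from torch.nn.init, got {type(fn)}")
    if inspect.getmodule(fn) is not torch.nn.init:
        raise ValueError("Initializer must come from torch.nn.init")
    if not getattr(fn, "__name__", "").endswith("_"):
        raise ValueError("Initializer must be in-place, e.g. kaiming_uniform_")


check_init_fn(torch.nn.init.kaiming_uniform_)  # accepted: default for lora_a_init
check_init_fn(torch.nn.init.zeros_)            # accepted: default for lora_b_init

try:
    check_init_fn(torch.nn.init.calculate_gain)  # rejected: name does not end with '_'
except ValueError as err:
    print(err)
```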

modelopt/torch/peft/conversion.py

Lines changed: 1 addition & 0 deletions
@@ -26,6 +26,7 @@
 from .config import PEFTConfig
 from .lora.layer import LoRAModule, LoRAModuleRegistry
 
+# TODO: Add test cases to cover these functions
 __all__ = [
     "freeze_base_weights",
     "freeze_lora_weights",

modelopt/torch/peft/lora/layer.py

Lines changed: 1 addition & 22 deletions
@@ -1,37 +1,16 @@
 """LoRA (Low-Rank Adaptation) module implementation."""
 
-import math
 from abc import abstractmethod
 from typing import Any
 
 import torch
 import torch.nn as nn
-import torch.nn.init as init
 
 from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
 
 from ..config import PEFTAttributeConfig
 
-__all__ = ["LoRAModule", "LoRAModuleRegistry", "get_init_methods"]
-
-
-def get_init_methods(init_method: str = "kaiming_init"):
-    """Get the target init method for the lora a and lora b weights.
-
-    Args:
-        init_method: the init method you want for the lora layer
-    """
-    if init_method == "kaiming_init":
-        return lambda weight: init.kaiming_uniform_(
-            weight, a=math.sqrt(5)
-        )  # LoRA A: Kaiming uniform
-    elif init_method == "zero_init":
-        return lambda weight: init.zeros_(weight)  # LoRA B: zeros
-    else:
-        raise ValueError(
-            f"Unsupported initialization method: '{init_method}'. "
-            "Supported methods: 'kaiming_init', 'zero_init'"
-        )
+__all__ = ["LoRAModule", "LoRAModuleRegistry"]
 
 
 class LoRAModule(DynamicModule):
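
Removing `get_init_methods` also changes the default LoRA A behaviour slightly: the old "kaiming_init" path called `kaiming_uniform_` with `a=math.sqrt(5)`, whereas the new config default is the bare `torch.nn.init.kaiming_uniform_`, which the plugin (see the megatron diff below) calls with no extra arguments, i.e. `a=0`. A small sketch of the mapping, for reference only:

```python
import math

import torch
import torch.nn.init as init

weight = torch.empty(32, 512)

# Old string-based API (removed above):
init.kaiming_uniform_(weight, a=math.sqrt(5))  # what "kaiming_init" did
init.zeros_(weight)                            # what "zero_init" did

# New callable-based defaults from PEFTAttributeConfig:
init.kaiming_uniform_(weight)  # lora_a_init default, no extra args (a=0)
init.zeros_(weight)            # lora_b_init default
```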

modelopt/torch/peft/lora/plugins/megatron.py

Lines changed: 5 additions & 9 deletions
@@ -31,7 +31,7 @@
 
 from ...config import PEFTAttributeConfig
 from ...custom import CUSTOM_MODEL_PLUGINS
-from ..layer import LoRAModule, LoRAModuleRegistry, get_init_methods
+from ..layer import LoRAModule, LoRAModuleRegistry
 
 DEFAULT_LORA_RANK = 64
 DEFAULT_SCALE = 1.0
@@ -130,23 +130,21 @@ def update_layer_lora(
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition
         """
-        lora_a_init = get_init_methods(attr_config.lora_a_init)
-        lora_b_init = get_init_methods(attr_config.lora_b_init)
         lora_a = nn.Linear(
             in_features=self.input_size,
             out_features=attr_config.rank,
             bias=False,
         )
         with torch.no_grad():
-            lora_a_init(lora_a.weight)
+            attr_config.lora_a_init(lora_a.weight)
 
         lora_b = ColumnParallelLinear(
             attr_config.rank,
             self.output_size,
             config=self.config,
             bias=False,
             gather_output=False,
-            init_method=lora_b_init,
+            init_method=attr_config.lora_b_init,
         )
 
         self._register_adapter_with_device(
@@ -204,16 +202,14 @@ def update_layer_lora(
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition
         """
-        lora_a_init = get_init_methods(attr_config.lora_a_init)
-        lora_b_init = get_init_methods(attr_config.lora_b_init)
         lora_a = RowParallelLinear(
             self.input_size,
             attr_config.rank,
             config=self.config,
             input_is_parallel=True,
             skip_bias_add=True,
             bias=False,
-            init_method=lora_a_init,
+            init_method=attr_config.lora_a_init,
         )
 
         lora_b = nn.Linear(
@@ -222,7 +218,7 @@ def update_layer_lora(
             bias=False,
         )
         with torch.no_grad():
-            lora_b_init(lora_b.weight)
+            attr_config.lora_b_init(lora_b.weight)
 
         self._register_adapter_with_device(
             adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
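
For readers without Megatron at hand, here is a framework-free sketch of the pattern the plugin now uses: the configured callables are applied directly to the adapter weights, with no string lookup in between. The sizes and variable names below are made up, and plain `nn.Linear` stands in for the `ColumnParallelLinear`/`RowParallelLinear` halves used in the real code.

```python
import torch
import torch.nn as nn

rank, in_features, out_features = 8, 512, 512
lora_a = nn.Linear(in_features, rank, bias=False)
lora_b = nn.Linear(rank, out_features, bias=False)

lora_a_init = torch.nn.init.kaiming_uniform_  # config default for lora_a_init
lora_b_init = torch.nn.init.zeros_            # config default for lora_b_init

with torch.no_grad():
    lora_a_init(lora_a.weight)  # was: get_init_methods("kaiming_init")(lora_a.weight)
    lora_b_init(lora_b.weight)  # was: get_init_methods("zero_init")(lora_b.weight)

# LoRA B starts at zero, so the freshly added adapter is a no-op at first.
assert torch.count_nonzero(lora_b.weight).item() == 0
```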

tests/gpu/torch/peft/test_megatron_peft.py

Lines changed: 17 additions & 20 deletions
@@ -1,7 +1,9 @@
+import copy
 from functools import partial
 
 import pytest
 import torch
+import torch.nn.init as init
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import get_device_counts, spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -51,8 +53,6 @@
         "*": {
             "rank": 32,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "zero_init",
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -66,8 +66,6 @@
         "*": {
             "rank": 128,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "zero_init",
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -81,8 +79,8 @@
         "*": {
             "rank": 32,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "kaiming_init",
+            "lora_a_init": init.kaiming_uniform_,
+            "lora_b_init": init.kaiming_uniform_,
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -96,8 +94,8 @@
         "*": {
             "rank": 128,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "kaiming_init",
+            "lora_a_init": init.kaiming_uniform_,
+            "lora_b_init": init.kaiming_uniform_,
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -111,8 +109,8 @@
         "*": {
             "rank": 8,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "kaiming_init",
+            "lora_a_init": init.kaiming_uniform_,
+            "lora_b_init": init.kaiming_uniform_,
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -127,8 +125,6 @@
         "*self_attention*": {
             "rank": 16,
             "scale": 1,
-            "lora_a_init": "kaiming_init",
-            "lora_b_init": "zero_init",
             "enable": True,
         },
         "*output_layer*": {"enable": False},
@@ -449,14 +445,15 @@ def test_adapter_gradient_flow_freeze_base_model(device_count, lora_config, tmp_
 
 def _test_adapter_gradient_flow_freeze_lora_model(lora_config, tmp_path, rank, size):
     hidden_size = 512
-    lora_config["freeze_lora_weights"] = True
-    lora_config["freeze_base_model"] = False
+    local_cfg = copy.deepcopy(lora_config)
+    local_cfg["freeze_lora_weights"] = True
+    local_cfg["freeze_base_model"] = False
 
     initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1)
     model = _gpt_model_provider(tp_size=size, hidden_size=hidden_size)
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
 
-    mtpf.update_model(model, lora_config)
+    mtpf.update_model(model, local_cfg)
     model.train()
 
     # Use a simple forward pass instead for grad check
@@ -569,7 +566,7 @@ def forward_func(mod):
             assert hasattr(module.weight_quantizer, "amax")
             assert getattr(module.input_quantizer, "amax") is not None
             assert getattr(module.weight_quantizer, "amax") is not None
-            # Check if the lora have teh quantizer, they should not have them.
+            # Check if the lora have the quantizer, they should not have them.
             for adapter_name in module._lora_adapters:
                 lora_a = module._lora_adapters[adapter_name]["lora_a"]
                 lora_b = module._lora_adapters[adapter_name]["lora_b"]
@@ -621,7 +618,7 @@ def forward_func(mod):
             assert hasattr(module.weight_quantizer, "amax")
             assert getattr(module.input_quantizer, "amax") is not None
             assert getattr(module.weight_quantizer, "amax") is not None
-            # Check if the lora have teh quantizer, they should not have them.
+            # Check if the lora have the quantizer, they should not have them.
             for adapter_name in module._lora_adapters:
                 lora_a = module._lora_adapters[adapter_name]["lora_a"]
                 lora_b = module._lora_adapters[adapter_name]["lora_b"]
@@ -701,7 +698,7 @@ def forward_func(mod):
             assert hasattr(module.weight_quantizer, "amax")
             assert getattr(module.input_quantizer, "amax") is not None
             assert getattr(module.weight_quantizer, "amax") is not None
-            # Check if the lora have teh quantizer, they should not have them.
+            # Check if the lora have the quantizer, they should not have them.
             for adapter_name in module._lora_adapters:
                 lora_a = module._lora_adapters[adapter_name]["lora_a"]
                 lora_b = module._lora_adapters[adapter_name]["lora_b"]
@@ -765,7 +762,7 @@ def forward_func(mod):
             assert hasattr(module.weight_quantizer, "amax")
             assert getattr(module.input_quantizer, "amax") is not None
             assert getattr(module.weight_quantizer, "amax") is not None
-            # Check if the lora have teh quantizer, they should not have them.
+            # Check if the lora have the quantizer, they should not have them.
             for adapter_name in module._lora_adapters:
                 lora_a = module._lora_adapters[adapter_name]["lora_a"]
                 lora_b = module._lora_adapters[adapter_name]["lora_b"]
@@ -784,7 +781,7 @@ def forward_func(mod):
         DEFAULT_LORA_CFG_RANDOM_INIT_TEST,
     ],
 )
-def test_mcore_lora_quantize_save_restore(device_count, lora_config, tmp_path):
+def test_mcore_lora_then_quantize_save_restore(device_count, lora_config, tmp_path):
     spawn_multiprocess_job(
         size=device_count,
         job=partial(_test_mcore_lora_then_quantize_save_restore, lora_config, str(tmp_path)),
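
One test fix is worth calling out: `_test_adapter_gradient_flow_freeze_lora_model` previously mutated the shared `lora_config` dict in place, so its `freeze_lora_weights`/`freeze_base_model` overrides could leak into other parametrized tests. A small illustrative sketch of the deep-copy pattern; the config contents and helper name here are invented, not taken from the test file.

```python
import copy

# Invented stand-in for a module-level config shared across parametrized tests.
SHARED_LORA_CFG = {"adapter_cfg": {"*": {"rank": 32, "scale": 1, "enable": True}}}


def build_freeze_lora_cfg():
    # Copy before mutating so per-test overrides never leak into the shared dict.
    cfg = copy.deepcopy(SHARED_LORA_CFG)
    cfg["freeze_lora_weights"] = True
    cfg["freeze_base_model"] = False
    return cfg


cfg = build_freeze_lora_cfg()
assert "freeze_lora_weights" not in SHARED_LORA_CFG  # shared config stays pristine
```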
