
Commit 81f8d06

Update init functions
Signed-off-by: Jingyu Xin <[email protected]>
1 parent ce6bead commit 81f8d06

File tree: 3 files changed, +31 -58 lines

modelopt/torch/peft/config.py

Lines changed: 5 additions & 33 deletions
@@ -16,11 +16,10 @@
 """Configuration classes for PEFT methods."""
 
 import math
-import pickle  # nosec B403 - Only checking picklability
 from collections.abc import Callable
 
 import torch.nn.init as init
-from pydantic import field_validator, model_validator
+from pydantic import field_validator
 
 from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
 
@@ -61,14 +60,14 @@ class PEFTAttributeConfig(ModeloptBaseConfig):
         description="Scaling factor for the LoRA output. Controls the magnitude of the adaptation.",
     )
 
-    lora_a_init: Callable[[object], None] | None = ModeloptField(
-        default=kaiming_init,
+    lora_a_init: str = ModeloptField(
+        default="kaiming_init",
         title="LoRA A matrix initializer",
         description="Custom initialization function for LoRA A matrix. Default to Kaiming uniform initialization.",
     )
 
-    lora_b_init: Callable[[object], None] | None = ModeloptField(
-        default=zero_init,
+    lora_b_init: str = ModeloptField(
+        default="zero_init",
         title="LoRA B matrix initializer",
         description="Custom initialization function for LoRA B matrix. Default to zero initialization.",
     )
@@ -89,33 +88,6 @@ def validate_scale(cls, v):
             raise ValueError("scale must be a positive number")
         return v
 
-    @model_validator(mode="after")
-    def validate_init_functions(self):
-        """Validate initialization functions are callable and picklable."""
-        if self.lora_a_init is not None and not callable(self.lora_a_init):
-            raise ValueError("lora_a_init must be callable")
-        if self.lora_b_init is not None and not callable(self.lora_b_init):
-            raise ValueError("lora_b_init must be callable")
-        if self.lora_a_init is not None:
-            try:
-                _del = pickle.dumps(self.lora_a_init)
-                del _del
-            except (pickle.PicklingError, TypeError, AttributeError) as e:
-                raise ValueError(
-                    f"lora_a_init cannot be pickled: {e}. "
-                    "Please use a module-level function instead of a lambda or nested function."
-                )
-        if self.lora_b_init is not None:
-            try:
-                _del = pickle.dumps(self.lora_b_init)
-                del _del
-            except (pickle.PicklingError, TypeError, AttributeError) as e:
-                raise ValueError(
-                    f"lora_b_init cannot be pickled: {e}. "
-                    "Please use a module-level function instead of a lambda or nested function."
-                )
-        return self
-
 
 # Type alias for adapter configuration
 PEFTAdapterCfgType = dict[str | Callable, PEFTAttributeConfig | dict]
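
Since lora_a_init and lora_b_init are now strings that name an initializer rather than arbitrary callables, the config is picklable by construction, which is presumably why the picklability validator above was removed. A minimal sketch (not part of this commit) of selecting the initializers by name; it assumes the remaining PEFTAttributeConfig fields keep their defaults:

from modelopt.torch.peft.config import PEFTAttributeConfig

# "kaiming_init" and "zero_init" are also the field defaults, so this is
# equivalent to the default config for these two fields.
attr_cfg = PEFTAttributeConfig(
    lora_a_init="kaiming_init",  # Kaiming uniform for the LoRA A matrix
    lora_b_init="zero_init",     # zeros for the LoRA B matrix
)
print(attr_cfg.lora_a_init, attr_cfg.lora_b_init)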

modelopt/torch/peft/lora/layer.py

Lines changed: 17 additions & 4 deletions
@@ -1,20 +1,33 @@
 """LoRA (Low-Rank Adaptation) module implementation."""
 
+import math
 import warnings
 from abc import abstractmethod
 from typing import Any
 
 import torch
 import torch.nn as nn
+import torch.nn.init as init
 
 from modelopt.torch.opt.dynamic import DynamicModule, _DMRegistryCls
 
 from ..config import PEFTAttributeConfig
 
-__all__ = [
-    "LoRAModule",
-    "LoRAModuleRegistry",
-]
+__all__ = ["LoRAModule", "LoRAModuleRegistry", "get_init_methods"]
+
+
+def get_init_methods(init_method: str = "kaiming_init"):
+    """Get the target init method for the lora a and lora b weights.
+
+    Args:
+        init_method: the init method you want for the lora layer
+    """
+    if init_method == "kaiming_init":
+        return lambda weight: init.kaiming_uniform_(
+            weight, a=math.sqrt(5)
+        )  # LoRA A: Kaiming uniform
+    elif init_method == "zero_init":
+        return lambda weight: init.zeros_(weight)  # LoRA B: zeros
 
 
 class LoRAModule(DynamicModule):
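
For reference, a small usage sketch (not from this commit) of the new helper: get_init_methods maps a name to a callable that initializes a weight tensor in place. Names other than "kaiming_init" and "zero_init" fall through and yield None.

import torch

from modelopt.torch.peft.lora.layer import get_init_methods

w = torch.empty(64, 1024)                   # hypothetical LoRA A weight shape
get_init_methods("kaiming_init")(w)         # Kaiming uniform, in place
get_init_methods("zero_init")(w)            # zeros, in place
assert get_init_methods("unknown") is None  # unrecognized names return None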

modelopt/torch/peft/lora/plugins/megatron.py

Lines changed: 9 additions & 21 deletions
@@ -15,12 +15,8 @@
 
 """Megatron-Core specific PEFT/LoRA plugins."""
 
-import math
-from collections.abc import Callable
-
 import torch
 import torch.nn as nn
-import torch.nn.init as init
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -35,7 +31,7 @@
 
 from ...config import PEFTAttributeConfig
 from ...custom import CUSTOM_MODEL_PLUGINS
-from ..layer import LoRAModule, LoRAModuleRegistry
+from ..layer import LoRAModule, LoRAModuleRegistry, get_init_methods
 
 DEFAULT_LORA_RANK = 64
 DEFAULT_SCALE = 1.0
@@ -73,18 +69,6 @@ class _MegatronParallelLoRABase(LoRAModule):
     LoRA implementations, reducing code duplication.
     """
 
-    def _get_init_methods(self, lora_a_init, lora_b_init) -> tuple[Callable, Callable]:
-        """Get initialization methods for LoRA A and B matrices.
-
-        Returns:
-            Tuple of (lora_a_init, lora_b_init) initialization functions
-        """
-        if lora_a_init is None:
-            lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))  # noqa: E731 # LoRA A: Kaiming uniform
-        if lora_b_init is None:
-            lora_b_init = lambda weight: init.zeros_(weight)  # noqa: E731 # LoRA B: zeros
-        return lora_a_init, lora_b_init
-
     def _register_adapter_with_device(
         self,
         adapter_name: str,
@@ -146,21 +130,23 @@ def update_layer_lora(
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition
         """
+        lora_a_init = get_init_methods(attr_config.lora_a_init)
+        lora_b_init = get_init_methods(attr_config.lora_b_init)
         lora_a = nn.Linear(
             in_features=self.input_size,
             out_features=attr_config.rank,
             bias=False,
         )
         with torch.no_grad():
-            attr_config.lora_b_init(lora_a.weight)  # type: ignore[misc]
+            lora_a_init(lora_a.weight)
 
         lora_b = ColumnParallelLinear(
             attr_config.rank,
             self.output_size,
             config=self.config,
             bias=False,
             gather_output=False,
-            init_method=attr_config.lora_a_init,
+            init_method=lora_b_init,
         )
 
         self._register_adapter_with_device(
@@ -218,14 +204,16 @@ def update_layer_lora(
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition
         """
+        lora_a_init = get_init_methods(attr_config.lora_a_init)
+        lora_b_init = get_init_methods(attr_config.lora_b_init)
         lora_a = RowParallelLinear(
             self.input_size,
             attr_config.rank,
            config=self.config,
            input_is_parallel=True,
            skip_bias_add=True,
            bias=False,
-            init_method=attr_config.lora_a_init,
+            init_method=lora_a_init,
         )
 
         lora_b = nn.Linear(
@@ -234,7 +222,7 @@ def update_layer_lora(
             bias=False,
         )
         with torch.no_grad():
-            attr_config.lora_b_init(lora_b.weight)  # type: ignore[misc]
+            lora_b_init(lora_b.weight)
 
         self._register_adapter_with_device(
             adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
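
Both update_layer_lora overrides now follow the same two-step flow: resolve the configured names once with get_init_methods, then either hand the callable to the Megatron parallel layer via init_method= or apply it directly under torch.no_grad() to the plain nn.Linear half. A rough sketch of that flow outside Megatron, with ordinary nn.Linear standing in for ColumnParallelLinear/RowParallelLinear and a hypothetical helper name and sizes:

import torch
import torch.nn as nn

from modelopt.torch.peft.lora.layer import get_init_methods


def build_lora_pair(in_features, out_features, rank,
                    lora_a_init="kaiming_init", lora_b_init="zero_init"):
    """Hypothetical helper mirroring the plugin flow with plain nn.Linear layers."""
    a_init = get_init_methods(lora_a_init)
    b_init = get_init_methods(lora_b_init)
    lora_a = nn.Linear(in_features, rank, bias=False)
    lora_b = nn.Linear(rank, out_features, bias=False)
    with torch.no_grad():
        a_init(lora_a.weight)  # LoRA A: Kaiming uniform by default
        b_init(lora_b.weight)  # LoRA B: zeros, so the adapter starts as a no-op
    return lora_a, lora_b


lora_a, lora_b = build_lora_pair(4096, 4096, rank=64)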
