 """Megatron-Core specific PEFT/LoRA plugins."""
 
-import math
-from collections.abc import Callable
-
 import torch
 import torch.nn as nn
-import torch.nn.init as init
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -35,7 +31,7 @@
 
 from ...config import PEFTAttributeConfig
 from ...custom import CUSTOM_MODEL_PLUGINS
-from ..layer import LoRAModule, LoRAModuleRegistry
+from ..layer import LoRAModule, LoRAModuleRegistry, get_init_methods
 
 DEFAULT_LORA_RANK = 64
 DEFAULT_SCALE = 1.0
@@ -73,18 +69,6 @@ class _MegatronParallelLoRABase(LoRAModule):
     LoRA implementations, reducing code duplication.
     """
 
-    def _get_init_methods(self, lora_a_init, lora_b_init) -> tuple[Callable, Callable]:
-        """Get initialization methods for LoRA A and B matrices.
-
-        Returns:
-            Tuple of (lora_a_init, lora_b_init) initialization functions
-        """
-        if lora_a_init is None:
-            lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))  # noqa: E731  # LoRA A: Kaiming uniform
-        if lora_b_init is None:
-            lora_b_init = lambda weight: init.zeros_(weight)  # noqa: E731  # LoRA B: zeros
-        return lora_a_init, lora_b_init
-
     def _register_adapter_with_device(
         self,
         adapter_name: str,
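
For context, the defaults that the removed helper encoded follow the standard LoRA initialization convention: Kaiming-uniform for LoRA A and zeros for LoRA B, so the low-rank product starts at zero and a freshly added adapter leaves the wrapped layer's output unchanged. The sketch below only illustrates that convention (toy sizes, plain `nn.Linear`, no tensor parallelism) and is not part of this diff.

```python
# Illustration only -- toy sizes, plain nn.Linear, no tensor parallelism.
import math

import torch
import torch.nn as nn
import torch.nn.init as init

rank, d_in, d_out = 8, 32, 64
lora_a = nn.Linear(d_in, rank, bias=False)
lora_b = nn.Linear(rank, d_out, bias=False)

with torch.no_grad():
    init.kaiming_uniform_(lora_a.weight, a=math.sqrt(5))  # LoRA A: Kaiming uniform
    init.zeros_(lora_b.weight)                            # LoRA B: zeros

x = torch.randn(4, d_in)
# B(A(x)) == 0 at initialization, so the adapter starts as an exact no-op.
assert torch.equal(lora_b(lora_a(x)), torch.zeros(4, d_out))
```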
@@ -146,21 +130,23 @@ def update_layer_lora(
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition
         """
+        lora_a_init = get_init_methods(attr_config.lora_a_init)
+        lora_b_init = get_init_methods(attr_config.lora_b_init)
         lora_a = nn.Linear(
             in_features=self.input_size,
             out_features=attr_config.rank,
             bias=False,
         )
         with torch.no_grad():
-            attr_config.lora_b_init(lora_a.weight)  # type: ignore[misc]
+            lora_a_init(lora_a.weight)
 
         lora_b = ColumnParallelLinear(
             attr_config.rank,
             self.output_size,
             config=self.config,
             bias=False,
             gather_output=False,
-            init_method=attr_config.lora_a_init,
+            init_method=lora_b_init,
         )
 
         self._register_adapter_with_device(
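
The forward path that consumes these registered adapters lives in `LoRAModule` (in `..layer`) and is not shown in this diff. A hedged, framework-agnostic sketch of the usual composition follows; the real implementation also has to deal with Megatron's `(output, bias)` return tuples.

```python
# Hypothetical sketch of the usual LoRA composition; the actual LoRAModule.forward
# in ..layer is not part of this diff and may differ in detail.
def apply_lora(base_output, x, lora_a, lora_b, scale, enable=True):
    """Add the low-rank update scale * B(A(x)) on top of the frozen base output."""
    if not enable:
        return base_output
    return base_output + scale * lora_b(lora_a(x))
```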
@@ -218,14 +204,16 @@ def update_layer_lora(
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition
         """
+        lora_a_init = get_init_methods(attr_config.lora_a_init)
+        lora_b_init = get_init_methods(attr_config.lora_b_init)
         lora_a = RowParallelLinear(
             self.input_size,
             attr_config.rank,
             config=self.config,
             input_is_parallel=True,
             skip_bias_add=True,
             bias=False,
-            init_method=attr_config.lora_a_init,
+            init_method=lora_a_init,
         )
 
         lora_b = nn.Linear(
@@ -234,7 +222,7 @@ def update_layer_lora(
             bias=False,
         )
         with torch.no_grad():
-            attr_config.lora_b_init(lora_b.weight)  # type: ignore[misc]
+            lora_b_init(lora_b.weight)
 
         self._register_adapter_with_device(
             adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
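
One design point worth noting: each wrapper shards only one side of the adapter, mirroring the wrapped layer. In the column-parallel case A is a plain `nn.Linear` and B is a `ColumnParallelLinear` split along its output dimension; in the row-parallel case A is a `RowParallelLinear` split along its input dimension and B is a plain `nn.Linear`. The sketch below spells out the per-rank weight shapes this implies; the function names are illustrative and not part of the diff.

```python
# Illustrative only: per-rank weight shapes implied by the two wrappers above.
# nn.Linear stores weights as (out_features, in_features); tp = tensor-parallel size.
def column_parallel_lora_shapes(in_features, out_features, rank, tp):
    return {
        "lora_a": (rank, in_features),          # replicated on every rank
        "lora_b": (out_features // tp, rank),   # output dim sharded, like the base layer
    }


def row_parallel_lora_shapes(in_features, out_features, rank, tp):
    return {
        "lora_a": (rank, in_features // tp),    # input dim sharded, like the base layer
        "lora_b": (out_features, rank),         # replicated on every rank
    }
```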