- """Tensor Parallel LoRA implementations for Megatron layers."""
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Megatron-Core specific PEFT/LoRA plugins."""

import math
from collections.abc import Callable

+ import torch
import torch.nn as nn
import torch.nn.init as init
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

- from ..config import PEFTAttributeConfig
- from .layer import LoRAModule, LoRAModuleRegistry
+ from ...config import PEFTAttributeConfig
+ from ..layer import LoRAModule, LoRAModuleRegistry

try:
+     from megatron.core.transformer.module import MegatronModule
+
    from modelopt.torch.quantization.plugins.megatron import (
        _MegatronColumnParallelLinear as QuantColumnParallelLinear,
    )
    from modelopt.torch.quantization.plugins.megatron import (
        _MegatronRowParallelLinear as QuantRowParallelLinear,
    )

-     QUANT_MODULES_AVAILABLE = True
+     MEGATRON_AVAILABLE = True
except ImportError:
-     QUANT_MODULES_AVAILABLE = False
+     MegatronModule = None
+     MEGATRON_AVAILABLE = False
+
+ from ...custom import CUSTOM_MODEL_PLUGINS

DEFAULT_LORA_RANK = 64
DEFAULT_SCALE = 1.0

+ __all__ = []
+
+
+ def megatron_replace_lora_module_hook(model: torch.nn.Module):
+     """Configure Megatron-Core model PEFT/LoRA support.
+
+     This callback is called before the LoRAModule replacement to configure
+     distributed checkpointing support. For each MegatronModule:
+     1. We enable heterogeneous distributed checkpointing
+
+     Note: LoRAModule already has built-in get_extra_state and set_extra_state methods,
+     so we don't need to register callbacks for them.
+     """
+     if not MEGATRON_AVAILABLE:
+         return
+
+     for name, module in model.named_modules():
+         if isinstance(module, MegatronModule):
+             # Enable heterogeneous distributed checkpointing
+             if hasattr(module, "config") and hasattr(
+                 module.config, "hetereogenous_dist_checkpoint"
+             ):
+                 module.config.hetereogenous_dist_checkpoint = True
+
+
+ # Register the hook
+ CUSTOM_MODEL_PLUGINS.add(megatron_replace_lora_module_hook)
+

class _MegatronParallelLoRABase(LoRAModule):
    """Base class for Megatron tensor parallel LoRA implementations.
@@ -107,22 +156,20 @@ def update_layer_lora(
            adapter_name: Name for the new adapter
            rank: Rank of the LoRA decomposition
        """
-         lora_a = ColumnParallelLinear(
-             self.input_size,
-             attr_config.rank,
-             config=self.config,
+         lora_a = nn.Linear(
+             in_features=self.input_size,
+             out_features=attr_config.rank,
            bias=False,
-             gather_output=True,
-             init_method=attr_config.lora_a_init,
-             disable_grad_reduce=getattr(self.config, "sequence_parallel", False),
        )
+         with torch.no_grad():
+             attr_config.lora_b_init(lora_a.weight)  # type: ignore[misc]

        lora_b = ColumnParallelLinear(
            attr_config.rank,
            self.output_size,
            config=self.config,
            bias=False,
-             gather_output=False,  # Keep output distributed like base layer
+             gather_output=False,
            init_method=attr_config.lora_a_init,
        )

@@ -144,11 +191,8 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
        state_dict = self.state_dict(prefix="", keep_vars=True)

        for adapter_name in self._lora_adapters:
-             lora_a_key = f"lora_a_{adapter_name}.weight"
            lora_b_key = f"lora_b_{adapter_name}.weight"

-             if lora_a_key in state_dict:
-                 lora_state_dict[lora_a_key] = state_dict[lora_a_key]
            if lora_b_key in state_dict:
                lora_state_dict[lora_b_key] = state_dict[lora_b_key]

@@ -194,14 +238,13 @@ def update_layer_lora(
            init_method=attr_config.lora_a_init,
        )

-         lora_b = ColumnParallelLinear(
-             attr_config.rank,
-             self.output_size,
-             config=self.config,
+         lora_b = nn.Linear(
+             in_features=attr_config.rank,
+             out_features=self.output_size,
            bias=False,
-             gather_output=True,
-             init_method=attr_config.lora_b_init,
        )
+         with torch.no_grad():
+             attr_config.lora_b_init(lora_b.weight)  # type: ignore[misc]

        self._register_adapter_with_device(
            adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
@@ -222,19 +265,13 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):

        for adapter_name in self._lora_adapters:
            lora_a_key = f"lora_a_{adapter_name}.weight"
-             lora_b_key = f"lora_b_{adapter_name}.weight"

            if lora_a_key in state_dict:
                lora_state_dict[lora_a_key] = state_dict[lora_a_key]
-             if lora_b_key in state_dict:
-                 lora_state_dict[lora_b_key] = state_dict[lora_b_key]

        lora_sharding_dims = {}
        for key in lora_state_dict:
-             if "lora_a_" in key:
-                 lora_sharding_dims[key] = 1
-             elif "lora_b_" in key:
-                 lora_sharding_dims[key] = 0
+             lora_sharding_dims[key] = 1

        if lora_state_dict:
            lora_sharded = make_sharded_tensors_for_checkpoint(
@@ -246,7 +283,7 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):


# Register quantized versions if available
- if QUANT_MODULES_AVAILABLE:
+ if MEGATRON_AVAILABLE:
    LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
        _LoRAMegatronColumnParallelLinear
    )
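For reference, the adapter wiring built by the two `update_layer_lora` variants reduces to the standard LoRA composition. The sketch below is a plain-PyTorch approximation with no tensor parallelism and illustrative sizes only; it assumes `LoRAModule` (defined in `..layer`, not shown in this diff) adds `scale * lora_b(lora_a(x))` to the base layer output, and it substitutes generic initializers for the `attr_config.lora_a_init` / `attr_config.lora_b_init` callables.

```python
import torch
import torch.nn as nn

# Illustrative sizes; real values come from the wrapped Megatron layer and PEFTAttributeConfig.
in_features, out_features, rank, scale = 1024, 4096, 64, 1.0

base = nn.Linear(in_features, out_features, bias=False)

# Column-parallel case from this diff: lora_a is a plain (replicated) nn.Linear and
# lora_b is the tensor-parallel factor, so only lora_b's weight would need sharding.
lora_a = nn.Linear(in_features, rank, bias=False)
lora_b = nn.Linear(rank, out_features, bias=False)
with torch.no_grad():
    nn.init.normal_(lora_a.weight, std=0.02)  # stand-in for one init callable
    nn.init.zeros_(lora_b.weight)             # the other factor starts at zero

x = torch.randn(2, in_features)
y = base(x) + scale * lora_b(lora_a(x))

# With one factor zero-initialized, the adapter is a no-op until training updates it.
assert torch.allclose(y, base(x))
```

This also reflects why the two `sharded_state_dict` overrides above each collect only one factor: the column-parallel wrapper keeps only the `lora_b_*` keys (its `lora_a` is an ordinary replicated `nn.Linear`), while the row-parallel wrapper keeps only the `lora_a_*` keys and shards them along dim 1, the input dimension split by `RowParallelLinear`.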