
Commit 1eb6677

Update lora implementations
1 parent d9a79c1 commit 1eb6677

File tree

4 files changed (+68, -88 lines)
Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 """LoRA (Low-Rank Adaptation) implementation for parameter-efficient fine-tuning."""
 
-from . import layer, tp_layer
+from . import layer, plugins
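
For downstream code, the effect of this one-line change together with the rename below is that the Megatron tensor-parallel LoRA module now lives under the plugins subpackage. A hedged one-line illustration, with the import path inferred from the new file location rather than taken from documentation:

# Previously the module was importable as modelopt.torch.peft.lora.tp_layer
# New location inferred from modelopt/torch/peft/lora/plugins/megatron.py:
from modelopt.torch.peft.lora.plugins import megatron as lora_megatron_plugin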
File renamed without changes.

modelopt/torch/peft/lora/tp_layer.py renamed to modelopt/torch/peft/lora/plugins/megatron.py

Lines changed: 67 additions & 30 deletions
@@ -1,31 +1,80 @@
-"""Tensor Parallel LoRA implementations for Megatron layers."""
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Megatron-Core specific PEFT/LoRA plugins."""
 
 import math
 from collections.abc import Callable
 
+import torch
 import torch.nn as nn
 import torch.nn.init as init
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 
-from ..config import PEFTAttributeConfig
-from .layer import LoRAModule, LoRAModuleRegistry
+from ...config import PEFTAttributeConfig
+from ..layer import LoRAModule, LoRAModuleRegistry
 
 try:
+    from megatron.core.transformer.module import MegatronModule
+
     from modelopt.torch.quantization.plugins.megatron import (
         _MegatronColumnParallelLinear as QuantColumnParallelLinear,
     )
     from modelopt.torch.quantization.plugins.megatron import (
         _MegatronRowParallelLinear as QuantRowParallelLinear,
     )
 
-    QUANT_MODULES_AVAILABLE = True
+    MEGATRON_AVAILABLE = True
 except ImportError:
-    QUANT_MODULES_AVAILABLE = False
+    MegatronModule = None
+    MEGATRON_AVAILABLE = False
+
+from ...custom import CUSTOM_MODEL_PLUGINS
 
 DEFAULT_LORA_RANK = 64
 DEFAULT_SCALE = 1.0
 
+__all__ = []
+
+
+def megatron_replace_lora_module_hook(model: torch.nn.Module):
+    """Configure Megatron-Core model PEFT/LoRA support.
+
+    This callback is called before the LoRAModule replacement to configure
+    distributed checkpointing support. For each MegatronModule:
+    1. We enable heterogeneous distributed checkpointing
+
+    Note: LoRAModule already has built-in get_extra_state and set_extra_state methods,
+    so we don't need to register callbacks for them.
+    """
+    if not MEGATRON_AVAILABLE:
+        return
+
+    for name, module in model.named_modules():
+        if isinstance(module, MegatronModule):
+            # Enable heterogeneous distributed checkpointing
+            if hasattr(module, "config") and hasattr(
+                module.config, "hetereogenous_dist_checkpoint"
+            ):
+                module.config.hetereogenous_dist_checkpoint = True
+
+
+# Register the hook
+CUSTOM_MODEL_PLUGINS.add(megatron_replace_lora_module_hook)
+
 
 class _MegatronParallelLoRABase(LoRAModule):
     """Base class for Megatron tensor parallel LoRA implementations.
@@ -107,22 +156,20 @@ def update_layer_lora(
             adapter_name: Name for the new adapter
             rank: Rank of the LoRA decomposition
         """
-        lora_a = ColumnParallelLinear(
-            self.input_size,
-            attr_config.rank,
-            config=self.config,
+        lora_a = nn.Linear(
+            in_features=self.input_size,
+            out_features=attr_config.rank,
             bias=False,
-            gather_output=True,
-            init_method=attr_config.lora_a_init,
-            disable_grad_reduce=getattr(self.config, "sequence_parallel", False),
         )
+        with torch.no_grad():
+            attr_config.lora_b_init(lora_a.weight)  # type: ignore[misc]
 
         lora_b = ColumnParallelLinear(
             attr_config.rank,
             self.output_size,
             config=self.config,
             bias=False,
-            gather_output=False,  # Keep output distributed like base layer
+            gather_output=False,
            init_method=attr_config.lora_a_init,
         )
 
@@ -144,11 +191,8 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         state_dict = self.state_dict(prefix="", keep_vars=True)
 
         for adapter_name in self._lora_adapters:
-            lora_a_key = f"lora_a_{adapter_name}.weight"
             lora_b_key = f"lora_b_{adapter_name}.weight"
 
-            if lora_a_key in state_dict:
-                lora_state_dict[lora_a_key] = state_dict[lora_a_key]
             if lora_b_key in state_dict:
                 lora_state_dict[lora_b_key] = state_dict[lora_b_key]
 
@@ -194,14 +238,13 @@ def update_layer_lora(
             init_method=attr_config.lora_a_init,
         )
 
-        lora_b = ColumnParallelLinear(
-            attr_config.rank,
-            self.output_size,
-            config=self.config,
+        lora_b = nn.Linear(
+            in_features=attr_config.rank,
+            out_features=self.output_size,
             bias=False,
-            gather_output=True,
-            init_method=attr_config.lora_b_init,
         )
+        with torch.no_grad():
+            attr_config.lora_b_init(lora_b.weight)  # type: ignore[misc]
 
         self._register_adapter_with_device(
             adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
@@ -222,19 +265,13 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
 
         for adapter_name in self._lora_adapters:
             lora_a_key = f"lora_a_{adapter_name}.weight"
-            lora_b_key = f"lora_b_{adapter_name}.weight"
 
             if lora_a_key in state_dict:
                 lora_state_dict[lora_a_key] = state_dict[lora_a_key]
-            if lora_b_key in state_dict:
-                lora_state_dict[lora_b_key] = state_dict[lora_b_key]
 
         lora_sharding_dims = {}
         for key in lora_state_dict:
-            if "lora_a_" in key:
-                lora_sharding_dims[key] = 1
-            elif "lora_b_" in key:
-                lora_sharding_dims[key] = 0
+            lora_sharding_dims[key] = 1
 
         if lora_state_dict:
             lora_sharded = make_sharded_tensors_for_checkpoint(
@@ -246,7 +283,7 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
 
 
 # Register quantized versions if available
-if QUANT_MODULES_AVAILABLE:
+if MEGATRON_AVAILABLE:
     LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
         _LoRAMegatronColumnParallelLinear
    )
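
Taken together, these hunks make each adapter a pair of a replicated nn.Linear on one side of the rank-r bottleneck and a Megatron tensor-parallel linear on the other (lora_b stays column-parallel in the first update_layer_lora, lora_a stays row-parallel in the second), and the sharded_state_dict changes export only the tensor-parallel half as sharded tensors. Below is a minimal, single-process sketch of the low-rank composition these adapters implement, with no tensor parallelism; the class name, dimensions, and init choices are illustrative assumptions, not the modelopt implementation (which takes its init functions from PEFTAttributeConfig).

# Minimal sketch of the LoRA update pattern used above:
# y = base(x) + scale * lora_b(lora_a(x)), with a rank-r bottleneck.
# ToyLoRALinear and all dimensions here are illustrative only; which half is
# sharded vs. replicated is handled by Megatron-Core in the real module.
import math

import torch
import torch.nn as nn


class ToyLoRALinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int = 64, scale: float = 1.0):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad_(False)  # frozen pretrained weight
        self.lora_a = nn.Linear(in_features, rank, bias=False)   # down-projection
        self.lora_b = nn.Linear(rank, out_features, bias=False)  # up-projection
        self.scale = scale
        # Common LoRA initialization (assumed here): Kaiming-style for A and
        # zeros for B, so the adapter starts as a no-op.
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base output plus the scaled low-rank correction.
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))


layer = ToyLoRALinear(in_features=1024, out_features=4096, rank=64)
y = layer(torch.randn(2, 1024))
print(y.shape)  # torch.Size([2, 4096])

The parallel placement changes only how the adapter weights are sharded and checkpointed; the arithmetic of the update is the same as in this sketch, and the sharded_state_dict hunks above show that only the tensor-parallel half of each adapter is registered as a sharded tensor.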

modelopt/torch/peft/plugins/megatron.py

Lines changed: 0 additions & 57 deletions
This file was deleted.
