
Commit bf33ae4

Update: support quantizing the LoRA layers
Signed-off-by: Jingyu Xin <[email protected]>
Parent commit: ed63e47

File tree: 3 files changed (+45, -13 lines)


modelopt/torch/peft/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@
 
 from . import mode
 from .config import *
+from .conversion import *
 from .convert import *
 
 # isort: off

modelopt/torch/peft/lora/tp_layer.py

Lines changed: 40 additions & 8 deletions
@@ -38,10 +38,8 @@ def _get_init_methods(self) -> tuple[Callable, Callable]:
         Returns:
             Tuple of (lora_a_init, lora_b_init) initialization functions
         """
-        # LoRA A uses Kaiming uniform initialization
-        lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))
-        # LoRA B is initialized to zero for stable training start
-        lora_b_init = lambda weight: init.zeros_(weight)
+        lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))  # noqa: E731  # LoRA A: Kaiming uniform
+        lora_b_init = lambda weight: init.zeros_(weight)  # noqa: E731  # LoRA B: zeros
         return lora_a_init, lora_b_init
 
     def _register_adapter_with_device(
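
A side note on the initialization this hunk keeps: LoRA A is drawn from a Kaiming-uniform distribution and LoRA B starts at zero, so the adapter's contribution B(Ax) is exactly zero at initialization and the wrapped layer initially behaves like the frozen base layer. Below is a minimal plain-PyTorch sketch of that property; the shapes and names are illustrative, not ModelOpt code.

    import math

    import torch
    from torch.nn import init

    in_features, out_features, rank = 16, 32, 4
    lora_a = torch.empty(rank, in_features)
    lora_b = torch.empty(out_features, rank)

    init.kaiming_uniform_(lora_a, a=math.sqrt(5))  # LoRA A: Kaiming uniform
    init.zeros_(lora_b)                            # LoRA B: zeros

    x = torch.randn(8, in_features)
    delta = x @ lora_a.T @ lora_b.T          # adapter output, shape (8, out_features)
    assert torch.count_nonzero(delta) == 0   # zero at init: base-layer outputs unchanged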
@@ -81,7 +79,7 @@ def _register_adapter_with_device(
 
 
 @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"})
-class _MegatronColumnParallelLinear(_MegatronParallelLoRABase):
+class _LoRAMegatronColumnParallelLinear(_MegatronParallelLoRABase):
     """LoRA implementation for Megatron ColumnParallelLinear layers.
 
     This implementation creates column-parallel LoRA adapters that match
@@ -124,7 +122,7 @@ def update_layer_lora(
 
 
 @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"})
-class _MegatronRowParallelLinear(_MegatronParallelLoRABase):
+class _LoRAMegatronRowParallelLinear(_MegatronParallelLoRABase):
     """LoRA implementation for Megatron RowParallelLinear layers.
 
     This implementation creates row-parallel LoRA adapters that match
@@ -170,8 +168,42 @@ def update_layer_lora(
 if QUANT_MODULES_AVAILABLE:
     # Register the same LoRA implementations for quantized modules
     LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
-        _MegatronColumnParallelLinear
+        _LoRAMegatronColumnParallelLinear
     )
     LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})(
-        _MegatronRowParallelLinear
+        _LoRAMegatronRowParallelLinear
     )
+
+    from modelopt.torch.quantization.nn import QuantModuleRegistry
+
+    class _QuantLoRAMegatronColumnParallelLinear(
+        _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear
+    ):
+        """Quantized LoRA ColumnParallelLinear that combines LoRA and quantization.
+
+        This class ensures that the base layer functionality is quantized while
+        preserving LoRA adapter functionality.
+        """
+
+        def _setup(self):
+            QuantColumnParallelLinear._setup(self)
+
+    class _QuantLoRAMegatronRowParallelLinear(
+        _LoRAMegatronRowParallelLinear, QuantRowParallelLinear
+    ):
+        """Quantized LoRA RowParallelLinear that combines LoRA and quantization.
+
+        This class ensures that the base layer functionality is quantized while
+        preserving LoRA adapter functionality.
+        """
+
+        def _setup(self):
+            QuantRowParallelLinear._setup(self)
+
+    # Register LoRA modules in QuantModuleRegistry so they can be quantized
+    QuantModuleRegistry.register(
+        {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"}
+    )(_QuantLoRAMegatronColumnParallelLinear)
+    QuantModuleRegistry.register(
+        {_LoRAMegatronRowParallelLinear: "lora_megatron_RowParallelLinear"}
+    )(_QuantLoRAMegatronRowParallelLinear)
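
For readers unfamiliar with the registry-as-decorator pattern used in this hunk, the sketch below is a self-contained toy: _Registry, LoRALinear, QuantLinear, and QuantLoRALinear are hypothetical stand-ins, not ModelOpt's LoRAModuleRegistry/QuantModuleRegistry or the Megatron layers. It shows why Registry.register({Original: "name"})(Replacement) works, and how a class inheriting from both the LoRA module and the quantized module keeps the adapter behaviour while _setup delegates to the quantized base, mirroring QuantColumnParallelLinear._setup(self) above.

    import math

    import torch
    import torch.nn as nn


    class _Registry:
        """Toy stand-in for LoRAModuleRegistry / QuantModuleRegistry."""

        def __init__(self):
            self._map = {}

        def register(self, mapping):
            # register({OriginalType: "name"}) returns a decorator; calling it as
            # register({Original: "name"})(Replacement) registers Replacement directly.
            def decorator(replacement_cls):
                for original_cls in mapping:
                    self._map[original_cls] = replacement_cls
                return replacement_cls

            return decorator


    ToyQuantModuleRegistry = _Registry()


    class LoRALinear(nn.Linear):
        """Toy LoRA wrapper: frozen base weight plus a low-rank adapter."""

        def add_adapter(self, rank=4):
            self.lora_a = nn.Parameter(torch.empty(rank, self.in_features))
            self.lora_b = nn.Parameter(torch.zeros(self.out_features, rank))
            nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))


    class QuantLinear(nn.Linear):
        """Toy quantized linear; a real _setup would attach weight/input quantizers."""

        def _setup(self):
            self.weight_quantizer = nn.Identity()  # placeholder for a real quantizer


    # Combined class: LoRA behaviour first in the MRO, quantized base layer second.
    @ToyQuantModuleRegistry.register({LoRALinear: "lora_linear"})
    class QuantLoRALinear(LoRALinear, QuantLinear):
        def _setup(self):
            QuantLinear._setup(self)  # quantize the frozen base layer only


    # A conversion pass can now look up how to turn a LoRA layer into its quantized form.
    assert ToyQuantModuleRegistry._map[LoRALinear] is QuantLoRALinear

In the diff, QuantModuleRegistry plays the role of the toy registry: a quantization pass can replace _LoRAMegatronColumnParallelLinear (and likewise the row-parallel variant) with its _QuantLoRAMegatron counterpart, quantizing the frozen base layer while preserving the LoRA adapters, as the class docstrings describe.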

modelopt/torch/peft/plugins/__init__.py

Lines changed: 4 additions & 5 deletions
@@ -15,8 +15,7 @@
 
 """PEFT/LoRA plugins for various frameworks."""
 
-# Import plugins to register them
-try:
-    from . import megatron
-except ImportError:
-    pass  # Megatron not available
+from contextlib import suppress
+
+with suppress(ImportError):
+    from . import megatron as _megatron
