@@ -38,10 +38,8 @@ def _get_init_methods(self) -> tuple[Callable, Callable]:
         Returns:
             Tuple of (lora_a_init, lora_b_init) initialization functions
         """
-        # LoRA A uses Kaiming uniform initialization
-        lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))
-        # LoRA B is initialized to zero for stable training start
-        lora_b_init = lambda weight: init.zeros_(weight)
+        lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))  # noqa: E731  # LoRA A: Kaiming uniform
+        lora_b_init = lambda weight: init.zeros_(weight)  # noqa: E731  # LoRA B: zeros
         return lora_a_init, lora_b_init
 
     def _register_adapter_with_device(
@@ -81,7 +79,7 @@ def _register_adapter_with_device(
 
 
 @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"})
-class _MegatronColumnParallelLinear(_MegatronParallelLoRABase):
+class _LoRAMegatronColumnParallelLinear(_MegatronParallelLoRABase):
     """LoRA implementation for Megatron ColumnParallelLinear layers.
 
     This implementation creates column-parallel LoRA adapters that match
@@ -124,7 +122,7 @@ def update_layer_lora(
 
 
 @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"})
-class _MegatronRowParallelLinear(_MegatronParallelLoRABase):
+class _LoRAMegatronRowParallelLinear(_MegatronParallelLoRABase):
    """LoRA implementation for Megatron RowParallelLinear layers.
 
     This implementation creates row-parallel LoRA adapters that match
@@ -170,8 +168,42 @@ def update_layer_lora(
 if QUANT_MODULES_AVAILABLE:
     # Register the same LoRA implementations for quantized modules
     LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
-        _MegatronColumnParallelLinear
+        _LoRAMegatronColumnParallelLinear
     )
     LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})(
-        _MegatronRowParallelLinear
+        _LoRAMegatronRowParallelLinear
     )
+
+    from modelopt.torch.quantization.nn import QuantModuleRegistry
+
+    class _QuantLoRAMegatronColumnParallelLinear(
+        _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear
+    ):
+        """Quantized LoRA ColumnParallelLinear that combines LoRA and quantization.
+
+        This class ensures that the base layer functionality is quantized while
+        preserving LoRA adapter functionality.
+        """
+
+        def _setup(self):
+            QuantColumnParallelLinear._setup(self)
+
+    class _QuantLoRAMegatronRowParallelLinear(
+        _LoRAMegatronRowParallelLinear, QuantRowParallelLinear
+    ):
+        """Quantized LoRA RowParallelLinear that combines LoRA and quantization.
+
+        This class ensures that the base layer functionality is quantized while
+        preserving LoRA adapter functionality.
+        """
+
+        def _setup(self):
+            QuantRowParallelLinear._setup(self)
+
+    # Register LoRA modules in QuantModuleRegistry so they can be quantized
+    QuantModuleRegistry.register(
+        {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"}
+    )(_QuantLoRAMegatronColumnParallelLinear)
+    QuantModuleRegistry.register(
+        {_LoRAMegatronRowParallelLinear: "lora_megatron_RowParallelLinear"}
+    )(_QuantLoRAMegatronRowParallelLinear)
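
Reviewer note: the initialization convention kept by the first hunk (LoRA A via Kaiming uniform, LoRA B zeroed) means a freshly added adapter contributes nothing until B is trained. Below is a minimal standalone sketch of that behavior using plain torch.nn.Linear layers and made-up rank/sizes, not the Megatron parallel layers touched by this patch:

    import math

    import torch
    from torch import nn
    from torch.nn import init

    # Sketch only: LoRA A gets Kaiming-uniform weights, LoRA B starts at zero,
    # so the adapter path adds exactly nothing until B is updated by training.
    rank, d_in, d_out = 8, 1024, 1024          # hypothetical rank and layer sizes
    lora_a = nn.Linear(d_in, rank, bias=False)
    lora_b = nn.Linear(rank, d_out, bias=False)

    init.kaiming_uniform_(lora_a.weight, a=math.sqrt(5))
    init.zeros_(lora_b.weight)

    x = torch.randn(2, d_in)
    assert torch.all(lora_b(lora_a(x)) == 0)   # base output is unchanged at init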
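The new _QuantLoRAMegatron* classes lean on Python's MRO: the LoRA class is listed first among the bases so its methods are found first, while _setup is delegated explicitly to the quantized parent. A toy sketch of that pattern with hypothetical stand-in classes (not ModelOpt's or Megatron's actual APIs):

    # Toy sketch of the mixin pattern; every name here is a stand-in.
    class ParallelLinear:                       # stands in for a parallel linear layer
        def forward(self, x):
            return f"base({x})"

    class LoRAParallelLinear(ParallelLinear):   # stands in for _LoRAMegatron*ParallelLinear
        def forward(self, x):
            # cooperative super() so classes later in the MRO still run
            return super().forward(x) + " + lora"

    class QuantParallelLinear(ParallelLinear):  # stands in for Quant*ParallelLinear
        def _setup(self):
            self.quantized = True

        def forward(self, x):
            return "quant[" + super().forward(x) + "]"

    class QuantLoRAParallelLinear(LoRAParallelLinear, QuantParallelLinear):
        # LoRA base listed first so adapter behavior wins attribute lookup;
        # setup is delegated explicitly to the quant parent, mirroring the diff.
        def _setup(self):
            QuantParallelLinear._setup(self)

    m = QuantLoRAParallelLinear()
    m._setup()
    print(m.quantized)     # True
    print(m.forward("x"))  # quant[base(x)] + lora

Registering _LoRAMegatron* in QuantModuleRegistry (last hunk) is what lets a later quantization pass swap a LoRA-converted layer to the combined class, instead of requiring a fixed conversion order.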