@@ -38,10 +38,8 @@ def _get_init_methods(self) -> tuple[Callable, Callable]:
         Returns:
             Tuple of (lora_a_init, lora_b_init) initialization functions
         """
-        # LoRA A uses Kaiming uniform initialization
-        lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))
-        # LoRA B is initialized to zero for stable training start
-        lora_b_init = lambda weight: init.zeros_(weight)
+        lora_a_init = lambda weight: init.kaiming_uniform_(weight, a=math.sqrt(5))  # noqa: E731  # LoRA A: Kaiming uniform
+        lora_b_init = lambda weight: init.zeros_(weight)  # noqa: E731  # LoRA B: zeros
         return lora_a_init, lora_b_init
 
     def _register_adapter_with_device(
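
A note on the initialization scheme kept by this hunk: with LoRA B zero-initialized, the low-rank update B·A is identically zero, so the adapted layer reproduces the base layer exactly at the first training step, while the Kaiming-uniform LoRA A keeps a useful gradient scale. A minimal sketch of that property, assuming standard torch.nn.init and hypothetical shapes (rank 8, hidden size 16):

    import math

    import torch
    from torch.nn import init

    # Hypothetical shapes for illustration: rank-8 LoRA on a 16-dim linear layer.
    lora_a = torch.empty(8, 16)  # maps hidden -> rank
    lora_b = torch.empty(16, 8)  # maps rank -> hidden
    init.kaiming_uniform_(lora_a, a=math.sqrt(5))  # LoRA A: Kaiming uniform
    init.zeros_(lora_b)                            # LoRA B: zeros

    x = torch.randn(4, 16)
    delta = x @ lora_a.t() @ lora_b.t()  # the adapter's contribution to the output
    assert torch.all(delta == 0)  # zero at init: adapted layer == base layer
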
@@ -81,7 +79,7 @@ def _register_adapter_with_device(
 
 
 @LoRAModuleRegistry.register({ColumnParallelLinear: "megatron_ColumnParallelLinear"})
-class _MegatronColumnParallelLinear(_MegatronParallelLoRABase):
+class _LoRAMegatronColumnParallelLinear(_MegatronParallelLoRABase):
     """LoRA implementation for Megatron ColumnParallelLinear layers.
 
     This implementation creates column-parallel LoRA adapters that match
@@ -124,7 +122,7 @@ def update_layer_lora(
 
 
 @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"})
-class _MegatronRowParallelLinear(_MegatronParallelLoRABase):
+class _LoRAMegatronRowParallelLinear(_MegatronParallelLoRABase):
     """LoRA implementation for Megatron RowParallelLinear layers.
 
     This implementation creates row-parallel LoRA adapters that match
@@ -170,8 +168,42 @@ def update_layer_lora(
 if QUANT_MODULES_AVAILABLE:
     # Register the same LoRA implementations for quantized modules
     LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
-        _MegatronColumnParallelLinear
+        _LoRAMegatronColumnParallelLinear
     )
     LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})(
-        _MegatronRowParallelLinear
+        _LoRAMegatronRowParallelLinear
     )
+
+    from modelopt.torch.quantization.nn import QuantModuleRegistry
+
+    class _QuantLoRAMegatronColumnParallelLinear(
+        _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear
+    ):
+        """Quantized LoRA ColumnParallelLinear that combines LoRA and quantization.
+
+        This class ensures that the base layer functionality is quantized while
+        preserving LoRA adapter functionality.
+        """
+
+        def _setup(self):
+            QuantColumnParallelLinear._setup(self)
+
+    class _QuantLoRAMegatronRowParallelLinear(
+        _LoRAMegatronRowParallelLinear, QuantRowParallelLinear
+    ):
+        """Quantized LoRA RowParallelLinear that combines LoRA and quantization.
+
+        This class ensures that the base layer functionality is quantized while
+        preserving LoRA adapter functionality.
+        """
+
+        def _setup(self):
+            QuantRowParallelLinear._setup(self)
+
+    # Register LoRA modules in QuantModuleRegistry so they can be quantized
+    QuantModuleRegistry.register(
+        {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"}
+    )(_QuantLoRAMegatronColumnParallelLinear)
+    QuantModuleRegistry.register(
+        {_LoRAMegatronRowParallelLinear: "lora_megatron_RowParallelLinear"}
+    )(_QuantLoRAMegatronRowParallelLinear)
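
The explicit `_setup` overrides above matter because both parents define `_setup`: under Python's method resolution order, `_QuantLoRAMegatronColumnParallelLinear(_LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear)` would otherwise resolve `_setup` through the LoRA base first. A minimal sketch with hypothetical stand-in classes showing how the override pins the quantization setup:

    class LoRALinear:
        def _setup(self):
            print("LoRA setup")

    class QuantLinear:
        def _setup(self):
            print("quant setup")

    class QuantLoRALinear(LoRALinear, QuantLinear):
        # MRO is (QuantLoRALinear, LoRALinear, QuantLinear, object), so without
        # this override, _setup would resolve to LoRALinear._setup.
        def _setup(self):
            QuantLinear._setup(self)  # pin the quantization setup explicitly

    QuantLoRALinear()._setup()  # prints "quant setup"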