@@ -22,27 +22,20 @@
 import torch.nn as nn
 import torch.nn.init as init
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 
-from ...config import PEFTAttributeConfig
-from ..layer import LoRAModule, LoRAModuleRegistry
-
-try:
-    from megatron.core.transformer.module import MegatronModule
-
-    from modelopt.torch.quantization.plugins.megatron import (
-        _MegatronColumnParallelLinear as QuantColumnParallelLinear,
-    )
-    from modelopt.torch.quantization.plugins.megatron import (
-        _MegatronRowParallelLinear as QuantRowParallelLinear,
-    )
-
-    MEGATRON_AVAILABLE = True
-except ImportError:
-    MegatronModule = None
-    MEGATRON_AVAILABLE = False
+from modelopt.torch.quantization.nn import QuantModuleRegistry
+from modelopt.torch.quantization.plugins.megatron import (
+    _MegatronColumnParallelLinear as QuantColumnParallelLinear,
+)
+from modelopt.torch.quantization.plugins.megatron import (
+    _MegatronRowParallelLinear as QuantRowParallelLinear,
+)
 
+from ...config import PEFTAttributeConfig
 from ...custom import CUSTOM_MODEL_PLUGINS
+from ..layer import LoRAModule, LoRAModuleRegistry
 
 DEFAULT_LORA_RANK = 64
 DEFAULT_SCALE = 1.0
@@ -60,9 +53,6 @@ def megatron_replace_lora_module_hook(model: torch.nn.Module):
     Note: LoRAModule already has built-in get_extra_state and set_extra_state methods,
     so we don't need to register callbacks for them.
     """
-    if not MEGATRON_AVAILABLE:
-        return
-
     for name, module in model.named_modules():
         if isinstance(module, MegatronModule):
             # Enable heterogeneous distributed checkpointing
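With the try/except ImportError block and the MEGATRON_AVAILABLE early return removed in the hunks above, this plugin module now assumes that megatron.core and the Megatron quantization plugin import cleanly. A minimal sketch of the kind of availability check that is then expected to live wherever this plugin gets imported; the flag name and the commented import path are illustrative assumptions, not code from this change:

# Sketch only (assumption): guard the optional Megatron dependency once, at the point
# where this plugin module is imported, instead of inside the plugin itself.
try:
    import megatron.core  # noqa: F401  # present only when Megatron-LM is installed
    HAS_MEGATRON = True
except ImportError:
    HAS_MEGATRON = False

if HAS_MEGATRON:
    # The importing package would pull in the plugin module here, e.g.
    # from .megatron import *  # hypothetical path, shown only as a comment
    pass

Keeping the check outside the module means the module-level registrations later in the file only ever run after all of the top-level imports have succeeded.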
@@ -283,43 +273,41 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
 
 
 # Register quantized versions if available
-if MEGATRON_AVAILABLE:
-    LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
-        _LoRAMegatronColumnParallelLinear
-    )
-    LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})(
-        _LoRAMegatronRowParallelLinear
-    )
-
-    from modelopt.torch.quantization.nn import QuantModuleRegistry
-
-    class _QuantLoRAMegatronColumnParallelLinear(
-        _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear
-    ):
-        """Quantized LoRA ColumnParallelLinear that combines LoRA and quantization.
-
-        This class ensures that the base layer functionality is quantized while
-        preserving LoRA adapter functionality.
-        """
+LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
+    _LoRAMegatronColumnParallelLinear
+)
+LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})(
+    _LoRAMegatronRowParallelLinear
+)
 
-        def _setup(self):
-            QuantColumnParallelLinear._setup(self)
 
-    class _QuantLoRAMegatronRowParallelLinear(
-        _LoRAMegatronRowParallelLinear, QuantRowParallelLinear
-    ):
-        """Quantized LoRA RowParallelLinear that combines LoRA and quantization.
+class _QuantLoRAMegatronColumnParallelLinear(
+    _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear
+):
+    """Quantized LoRA ColumnParallelLinear that combines LoRA and quantization.
 
-        This class ensures that the base layer functionality is quantized while
-        preserving LoRA adapter functionality.
-        """
+    This class ensures that the base layer functionality is quantized while
+    preserving LoRA adapter functionality.
+    """
+
+    def _setup(self):
+        QuantColumnParallelLinear._setup(self)
+
+
+class _QuantLoRAMegatronRowParallelLinear(_LoRAMegatronRowParallelLinear, QuantRowParallelLinear):
+    """Quantized LoRA RowParallelLinear that combines LoRA and quantization.
+
+    This class ensures that the base layer functionality is quantized while
+    preserving LoRA adapter functionality.
+    """
+
+    def _setup(self):
+        QuantRowParallelLinear._setup(self)
 
-        def _setup(self):
-            QuantRowParallelLinear._setup(self)
 
-    QuantModuleRegistry.register(
-        {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"}
-    )(_QuantLoRAMegatronColumnParallelLinear)
-    QuantModuleRegistry.register(
-        {_LoRAMegatronRowParallelLinear: "lora_megatron_RowParallelLinear"}
-    )(_QuantLoRAMegatronRowParallelLinear)
+QuantModuleRegistry.register(
+    {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"}
+)(_QuantLoRAMegatronColumnParallelLinear)
+QuantModuleRegistry.register({_LoRAMegatronRowParallelLinear: "lora_megatron_RowParallelLinear"})(
+    _QuantLoRAMegatronRowParallelLinear
+)
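The registration at the end is symmetric: the quantized parallel-linear classes are registered in LoRAModuleRegistry (so LoRA conversion can wrap an already-quantized layer), and the LoRA parallel-linear classes are registered in QuantModuleRegistry (so quantization can wrap a layer that already carries LoRA adapters), presumably so that converting in either order resolves to a class with both behaviors. The _setup override that calls the quantized base class directly is the key detail; below is a tiny self-contained illustration of that mixin pattern, using stand-in classes rather than the modelopt or Megatron ones:

# Stand-in classes for illustration only; they are not the modelopt or Megatron types.
class LoRABase:
    def _setup(self):
        print("LoRA-only setup")


class QuantLinear:
    def _setup(self):
        print("quantized-linear setup")


class QuantLoRALinear(LoRABase, QuantLinear):
    def _setup(self):
        # As in the diff: delegate to the quantized base explicitly, so the quantization
        # setup runs even though the LoRA class comes first in the MRO.
        QuantLinear._setup(self)


QuantLoRALinear()._setup()  # prints "quantized-linear setup"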