Commit ce6bead

Remove the import error check

Signed-off-by: Jingyu Xin <[email protected]>

1 parent ffef564

File tree: 1 file changed, +43 -55 lines changed


modelopt/torch/peft/lora/plugins/megatron.py

Lines changed: 43 additions & 55 deletions
@@ -22,27 +22,20 @@
 import torch.nn as nn
 import torch.nn.init as init
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint

-from ...config import PEFTAttributeConfig
-from ..layer import LoRAModule, LoRAModuleRegistry
-
-try:
-    from megatron.core.transformer.module import MegatronModule
-
-    from modelopt.torch.quantization.plugins.megatron import (
-        _MegatronColumnParallelLinear as QuantColumnParallelLinear,
-    )
-    from modelopt.torch.quantization.plugins.megatron import (
-        _MegatronRowParallelLinear as QuantRowParallelLinear,
-    )
-
-    MEGATRON_AVAILABLE = True
-except ImportError:
-    MegatronModule = None
-    MEGATRON_AVAILABLE = False
+from modelopt.torch.quantization.nn import QuantModuleRegistry
+from modelopt.torch.quantization.plugins.megatron import (
+    _MegatronColumnParallelLinear as QuantColumnParallelLinear,
+)
+from modelopt.torch.quantization.plugins.megatron import (
+    _MegatronRowParallelLinear as QuantRowParallelLinear,
+)

+from ...config import PEFTAttributeConfig
 from ...custom import CUSTOM_MODEL_PLUGINS
+from ..layer import LoRAModule, LoRAModuleRegistry

 DEFAULT_LORA_RANK = 64
 DEFAULT_SCALE = 1.0
@@ -60,9 +53,6 @@ def megatron_replace_lora_module_hook(model: torch.nn.Module):
     Note: LoRAModule already has built-in get_extra_state and set_extra_state methods,
     so we don't need to register callbacks for them.
     """
-    if not MEGATRON_AVAILABLE:
-        return
-
     for name, module in model.named_modules():
         if isinstance(module, MegatronModule):
             # Enable heterogeneous distributed checkpointing
@@ -283,43 +273,41 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):


 # Register quantized versions if available
-if MEGATRON_AVAILABLE:
-    LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
-        _LoRAMegatronColumnParallelLinear
-    )
-    LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})(
-        _LoRAMegatronRowParallelLinear
-    )
-
-    from modelopt.torch.quantization.nn import QuantModuleRegistry
-
-    class _QuantLoRAMegatronColumnParallelLinear(
-        _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear
-    ):
-        """Quantized LoRA ColumnParallelLinear that combines LoRA and quantization.
-
-        This class ensures that the base layer functionality is quantized while
-        preserving LoRA adapter functionality.
-        """
+LoRAModuleRegistry.register({QuantColumnParallelLinear: "quant_megatron_ColumnParallelLinear"})(
+    _LoRAMegatronColumnParallelLinear
+)
+LoRAModuleRegistry.register({QuantRowParallelLinear: "quant_megatron_RowParallelLinear"})(
+    _LoRAMegatronRowParallelLinear
+)

-        def _setup(self):
-            QuantColumnParallelLinear._setup(self)

-    class _QuantLoRAMegatronRowParallelLinear(
-        _LoRAMegatronRowParallelLinear, QuantRowParallelLinear
-    ):
-        """Quantized LoRA RowParallelLinear that combines LoRA and quantization.
+class _QuantLoRAMegatronColumnParallelLinear(
+    _LoRAMegatronColumnParallelLinear, QuantColumnParallelLinear
+):
+    """Quantized LoRA ColumnParallelLinear that combines LoRA and quantization.

-        This class ensures that the base layer functionality is quantized while
-        preserving LoRA adapter functionality.
-        """
+    This class ensures that the base layer functionality is quantized while
+    preserving LoRA adapter functionality.
+    """
+
+    def _setup(self):
+        QuantColumnParallelLinear._setup(self)
+
+
+class _QuantLoRAMegatronRowParallelLinear(_LoRAMegatronRowParallelLinear, QuantRowParallelLinear):
+    """Quantized LoRA RowParallelLinear that combines LoRA and quantization.
+
+    This class ensures that the base layer functionality is quantized while
+    preserving LoRA adapter functionality.
+    """
+
+    def _setup(self):
+        QuantRowParallelLinear._setup(self)

-        def _setup(self):
-            QuantRowParallelLinear._setup(self)

-    QuantModuleRegistry.register(
-        {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"}
-    )(_QuantLoRAMegatronColumnParallelLinear)
-    QuantModuleRegistry.register(
-        {_LoRAMegatronRowParallelLinear: "lora_megatron_RowParallelLinear"}
-    )(_QuantLoRAMegatronRowParallelLinear)
+QuantModuleRegistry.register(
+    {_LoRAMegatronColumnParallelLinear: "lora_megatron_ColumnParallelLinear"}
+)(_QuantLoRAMegatronColumnParallelLinear)
+QuantModuleRegistry.register({_LoRAMegatronRowParallelLinear: "lora_megatron_RowParallelLinear"})(
+    _QuantLoRAMegatronRowParallelLinear
+)
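
For readers skimming the diff, the pattern this commit keeps in place is a class registry plus cooperative multiple inheritance: a registry maps a base module class to the wrapper class that should replace it, and the wrapper layers LoRA (and quantization) behavior on top while delegating setup to a parent class. The sketch below is a minimal, self-contained illustration of that general idea only; every name in it (SimpleRegistry, BaseLinear, LoRALinear) is hypothetical and is not the ModelOpt or Megatron-Core API.

# Minimal sketch of the registry + in-place class-swap pattern (hypothetical names,
# not the actual LoRAModuleRegistry / QuantModuleRegistry API).
import torch
import torch.nn as nn


class SimpleRegistry:
    """Maps a base module class to the wrapper class that should replace it."""

    def __init__(self):
        self._map: dict[type, type] = {}

    def register(self, base_cls: type):
        def decorator(wrapper_cls: type):
            self._map[base_cls] = wrapper_cls
            return wrapper_cls

        return decorator

    def convert(self, module: nn.Module) -> nn.Module:
        wrapper_cls = self._map.get(type(module))
        if wrapper_cls is not None:
            module.__class__ = wrapper_cls  # swap the class in place, keep existing parameters
            module._setup()                 # let the wrapper attach its extra state
        return module


class BaseLinear(nn.Linear):
    """Stand-in for a (tensor-parallel) linear layer."""


class LoRALinear(BaseLinear):
    """Wrapper that adds LoRA adapters on top of the base layer."""

    def _setup(self):
        # Low-rank adapters; rank 8 is arbitrary for this sketch.
        self.lora_a = nn.Linear(self.in_features, 8, bias=False)
        self.lora_b = nn.Linear(8, self.out_features, bias=False)

    def forward(self, x):
        return super().forward(x) + self.lora_b(self.lora_a(x))


registry = SimpleRegistry()
registry.register(BaseLinear)(LoRALinear)

layer = registry.convert(BaseLinear(16, 16))
print(type(layer).__name__, layer(torch.randn(2, 16)).shape)  # LoRALinear torch.Size([2, 16])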
