Commit 7ec3e5b

Force to check the mcore model
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 1eb6677 commit 7ec3e5b

2 files changed: +17 -1 lines changed

modelopt/torch/peft/conversion.py

Lines changed: 0 additions & 1 deletion
@@ -38,7 +38,6 @@ def convert_to_peft_model(model: ModelLikeModule, config: PEFTConfig) -> Convert
     # initialize the true module if necessary
     model = model.init_modellike() if isinstance(model, ModelLikeModule) else model

-    # TODO: Replace to LoRA module
     replace_lora_module(model, version=ModeloptStateManager(model).state_version, config=config)

     metadata = {}

modelopt/torch/peft/convert.py

Lines changed: 17 additions & 0 deletions
@@ -24,9 +24,16 @@
 from modelopt.torch.peft.config import PEFTConfig
 from modelopt.torch.peft.conversion import add_adapter

+try:
+    from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+
+    MEGATRON_LAYERS = (ColumnParallelLinear, RowParallelLinear)
+except Exception:
+    MEGATRON_LAYERS = ()
+
 from .lora.layer import LoRAModule
 from .mode import PEFTModeRegistry

 __all__ = [
     "disable_adapters",
     "enable_adapters",
@@ -53,6 +60,8 @@ def update_model(
     Returns:
         The updated model with LoRA adapters
     """
+    assert is_megatron_core_model(model), "We only support mcore format for the PEFT mode"
+
     # Check if model is already in PEFT mode by looking for LoRA modules
     if not is_peft_model(model):
         apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
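
With this guard, update_model() now fails fast on anything that is not a Megatron-Core model. A hedged sketch of the resulting behavior, assuming the signature update_model(model, config) suggested by the truncated hunk header, and that the assert runs before config is consulted (as the diff shows):

import torch.nn as nn

from modelopt.torch.peft.convert import update_model

plain = nn.Linear(16, 16)  # ordinary PyTorch module, no tensor-parallel layers

try:
    update_model(plain, None)  # config is never reached: the assert fires first
except AssertionError as err:
    print(err)  # -> We only support mcore format for the PEFT mode
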
@@ -218,3 +227,11 @@ def get_adapter_states(model):
         adapter_states[module_name] = module_adapters

     return adapter_states
+
+
+def is_megatron_core_model(model) -> bool:
+    if MEGATRON_LAYERS:
+        for m in model.modules():
+            if isinstance(m, MEGATRON_LAYERS):
+                return True
+    return False
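
The new helper scans model.modules() for any Megatron tensor-parallel linear layer, and degrades to False when the import guard left MEGATRON_LAYERS empty, so the assert above effectively makes PEFT mode require a Megatron-Core model. A self-contained behavioral sketch; FakeParallelLinear is a stand-in class invented for the demo, not part of Megatron or this diff:

import torch.nn as nn


class FakeParallelLinear(nn.Linear):
    """Demo stand-in for megatron's ColumnParallelLinear / RowParallelLinear."""


MEGATRON_LAYERS = (FakeParallelLinear,)


def is_megatron_core_model(model) -> bool:
    # Same logic as the helper added in this diff.
    if MEGATRON_LAYERS:
        for m in model.modules():
            if isinstance(m, MEGATRON_LAYERS):
                return True
    return False


print(is_megatron_core_model(nn.Sequential(nn.Linear(4, 4))))            # False
print(is_megatron_core_model(nn.Sequential(FakeParallelLinear(4, 4))))   # True
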
