Commit f17ca92
support mm llamapro (#2738)
1 parent 64cede0 commit f17ca92


swift/tuners/llamapro.py

Lines changed: 57 additions & 10 deletions
@@ -1,12 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from copy import deepcopy
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, fields
 from typing import Optional

 import torch
 from torch import nn

-from swift.llm import MODEL_ARCH_MAPPING, ModelKeys
+from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
 from swift.utils.logger import get_logger
 from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput

@@ -46,11 +46,9 @@ class LLaMAPro(SwiftAdapter):
     @staticmethod
     def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput:
         """Prepare a model with `LLaMAProConfig`"""
-        num_hidden_layers = None
-        if hasattr(model.config, 'num_hidden_layers'):
-            num_hidden_layers = model.config.num_hidden_layers
-        elif hasattr(model.config, 'num_layers'):
-            num_hidden_layers = model.config.num_layers
+        num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_hidden_layers')
+        if num_hidden_layers is None:
+            num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_layers')

         assert num_hidden_layers is not None, 'Cannot find num of layers config'
         assert num_hidden_layers % config.num_new_blocks == 0, f'Model layers {num_hidden_layers} ' \
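
Note: the removed hasattr chain only sees attributes on the top-level config, while multimodal ("mm") checkpoints typically keep num_hidden_layers on a nested text/LLM sub-config. Below is a minimal sketch of that kind of nested lookup with a hypothetical find_config_attr helper; it is an assumption about what HfConfigFactory.get_config_attr is used for here, not its actual implementation.

```python
from typing import Any, Optional


def find_config_attr(config: Any, attr_name: str) -> Optional[Any]:
    """Hypothetical helper: look up `attr_name` on the config itself, then on
    common nested sub-configs used by multimodal models (names assumed)."""
    value = getattr(config, attr_name, None)
    if value is not None:
        return value
    for sub_name in ('text_config', 'llm_config', 'language_config'):
        sub_config = getattr(config, sub_name, None)
        if sub_config is not None:
            value = getattr(sub_config, attr_name, None)
            if value is not None:
                return value
    return None


# Usage mirroring the patched prepare_model logic:
# num_hidden_layers = find_config_attr(model.config, 'num_hidden_layers')
# if num_hidden_layers is None:
#     num_hidden_layers = find_config_attr(model.config, 'num_layers')
```
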
@@ -60,8 +58,26 @@ def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -

         num_stride = num_hidden_layers // config.num_groups

-        # We only support decoder only model for now.
-        module_list = LLaMAPro._find_module_list(config, model)
+        try:
+            module_list = LLaMAPro._find_module_list(config, model)
+        except AssertionError as e:
+            model_type = LLaMAPro.search_correct_model_type(model)
+            if model_type is None:
+                language_model_name = SwiftAdapter.get_model_key_mapping(config.model_type, config).language_model
+                if language_model_name:
+                    if isinstance(language_model_name, str):
+                        language_model_name = [language_model_name]
+                    language_model = model.get_submodule(language_model_name[0])
+                    model_type = LLaMAPro.search_correct_model_type(language_model)
+                    if model_type:
+                        model = language_model
+
+            if model_type:
+                config.model_type = model_type
+                module_list = LLaMAPro._find_module_list(config, model)
+            else:
+                raise e
+
         new_module_list = nn.ModuleList()
         new_module_idx = []
         for idx, module in enumerate(module_list):
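
Note: the new try/except path covers multimodal wrappers whose decoder blocks live under a wrapped language model rather than on the top-level module. A rough, self-contained sketch of that recovery step using torch's nn.Module.get_submodule; the TinyLM / MultiModalModel classes and the 'language_model' attribute name are illustrative assumptions, not the real model classes.

```python
import torch
from torch import nn


class TinyLM(nn.Module):
    """Stand-in decoder-only LM with the usual `model.layers` module list."""

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Module()
        self.model.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])


class MultiModalModel(nn.Module):
    """Hypothetical multimodal wrapper: vision tower plus a wrapped language model."""

    def __init__(self) -> None:
        super().__init__()
        self.vision_tower = nn.Conv2d(3, 8, 3)
        self.language_model = TinyLM()


mm_model = MultiModalModel()
try:
    # The wrapper itself has no `model.layers`, so a direct lookup fails ...
    module_list = mm_model.get_submodule('model.layers')
except AttributeError:
    # ... so fall back to the wrapped language model, as the patched code does.
    language_model = mm_model.get_submodule('language_model')
    module_list = language_model.get_submodule('model.layers')

print(len(module_list))  # 4
```
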
@@ -107,7 +123,10 @@ def _update_module_attr(config: LLaMAProConfig, module_list):
         if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion',
                           'bluelm', 'ziya', 'skywork', 'deepseek-v2', 'minicpm', 'phi3', 'internlm2'):
             for idx, module in enumerate(module_list):
-                getattr(module, attention).layer_idx = idx
+                try:
+                    getattr(module, attention).layer_idx = idx
+                except AttributeError:
+                    getattr(module, 'cross_attn').layer_idx = idx
         elif model_type in ('chatglm', 'glm4'):
             for idx, module in enumerate(module_list):
                 getattr(module, attention).layer_number = idx
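
Note: some decoders targeted by this change mix self-attention and cross-attention blocks, so a single attribute name can raise AttributeError mid-loop; the patch falls back to cross_attn in that case. A toy sketch of the same fallback pattern (the block classes below are made up):

```python
from torch import nn


class SelfAttnBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.self_attn = nn.Module()   # stands in for the real attention layer


class CrossAttnBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.cross_attn = nn.Module()  # cross-attention blocks expose a different name


blocks = nn.ModuleList([SelfAttnBlock(), CrossAttnBlock(), SelfAttnBlock()])
attention = 'self_attn'

for idx, block in enumerate(blocks):
    try:
        getattr(block, attention).layer_idx = idx
    except AttributeError:
        # Mixed self-/cross-attention stacks: fall back to the cross_attn module.
        getattr(block, 'cross_attn').layer_idx = idx

print(blocks[1].cross_attn.layer_idx)  # 1
```
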
@@ -135,6 +154,34 @@ def get_model_key_mapping(cls, model_type, config) -> ModelKeys:
             'LLaMAPro only support models with o_proj and down_proj components.'
         return model_key_mapping

+    @classmethod
+    def search_correct_model_type(cls, module: nn.Module):
+        for arch_name, arch_type in MODEL_ARCH_MAPPING.items():
+            arch_type: ModelKeys
+            if getattr(arch_type, 'module_list') is None:
+                # Need to be a LLM arch
+                continue
+
+            matched = True
+            for f in fields(arch_type):
+                arch_str = getattr(arch_type, f.name)
+                if f.name == 'arch_name' or arch_str is None:
+                    continue
+
+                arch_str = arch_str.replace('{}', '0')
+                try:
+                    sub_module = module.get_submodule(arch_str)
+                    if sub_module is None:
+                        matched = False
+                except AttributeError:
+                    matched = False
+
+                if not matched:
+                    break
+
+            if matched:
+                return arch_name
+
     @staticmethod
     def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
         model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
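
Note: search_correct_model_type treats each ModelKeys entry as a set of submodule path templates, with '{}' standing for a layer index; it substitutes layer 0 and checks whether every path resolves via get_submodule. A self-contained sketch of that matching idea; ToyKeys and ToyLlama below are stand-ins, not the real ModelKeys or MODEL_ARCH_MAPPING.

```python
from dataclasses import dataclass, fields

from torch import nn


@dataclass
class ToyKeys:
    arch_name: str = 'toy-llama'
    module_list: str = 'model.layers'
    attention: str = 'model.layers.{}.self_attn'


class ToyLlama(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Module()
        self.model.layers = nn.ModuleList([nn.Module() for _ in range(2)])
        for layer in self.model.layers:
            layer.self_attn = nn.Module()


module = ToyLlama()
keys = ToyKeys()
matched = True
for f in fields(keys):
    pattern = getattr(keys, f.name)
    if f.name == 'arch_name' or pattern is None:
        continue
    # '{}' in the template stands for a layer index; probe layer 0.
    try:
        module.get_submodule(pattern.replace('{}', '0'))
    except AttributeError:
        matched = False
        break

print(matched)  # True: the toy module exposes every probed path
```
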
