@@ -1,12 +1,12 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from copy import deepcopy
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, fields
 from typing import Optional
 
 import torch
 from torch import nn
 
-from swift.llm import MODEL_ARCH_MAPPING, ModelKeys
+from swift.llm import MODEL_ARCH_MAPPING, HfConfigFactory, ModelKeys
 from swift.utils.logger import get_logger
 from .utils import ActivationMixin, SwiftAdapter, SwiftConfig, SwiftOutput
 
@@ -46,11 +46,9 @@ class LLaMAPro(SwiftAdapter):
     @staticmethod
     def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -> SwiftOutput:
         """Prepare a model with `LLaMAProConfig`"""
-        num_hidden_layers = None
-        if hasattr(model.config, 'num_hidden_layers'):
-            num_hidden_layers = model.config.num_hidden_layers
-        elif hasattr(model.config, 'num_layers'):
-            num_hidden_layers = model.config.num_layers
+        num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_hidden_layers')
+        if num_hidden_layers is None:
+            num_hidden_layers = HfConfigFactory.get_config_attr(model.config, 'num_layers')
 
         assert num_hidden_layers is not None, 'Cannot find num of layers config'
         assert num_hidden_layers % config.num_new_blocks == 0, f'Model layers {num_hidden_layers} ' \
@@ -60,8 +58,26 @@ def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -
 
         num_stride = num_hidden_layers // config.num_groups
 
-        # We only support decoder only model for now.
-        module_list = LLaMAPro._find_module_list(config, model)
+        try:
+            module_list = LLaMAPro._find_module_list(config, model)
+        except AssertionError as e:
+            model_type = LLaMAPro.search_correct_model_type(model)
+            if model_type is None:
+                language_model_name = SwiftAdapter.get_model_key_mapping(config.model_type, config).language_model
+                if language_model_name:
+                    if isinstance(language_model_name, str):
+                        language_model_name = [language_model_name]
+                    language_model = model.get_submodule(language_model_name[0])
+                    model_type = LLaMAPro.search_correct_model_type(language_model)
+                    if model_type:
+                        model = language_model
+
+            if model_type:
+                config.model_type = model_type
+                module_list = LLaMAPro._find_module_list(config, model)
+            else:
+                raise e
+
         new_module_list = nn.ModuleList()
         new_module_idx = []
         for idx, module in enumerate(module_list):
@@ -107,7 +123,10 @@ def _update_module_attr(config: LLaMAProConfig, module_list):
         if model_type in ('llama', 'mistral', 'qwen2', 'yi', 'gemma', 'deepseek', 'openbuddy', 'xverse', 'orion',
                           'bluelm', 'ziya', 'skywork', 'deepseek-v2', 'minicpm', 'phi3', 'internlm2'):
             for idx, module in enumerate(module_list):
-                getattr(module, attention).layer_idx = idx
+                try:
+                    getattr(module, attention).layer_idx = idx
+                except AttributeError:
+                    getattr(module, 'cross_attn').layer_idx = idx
         elif model_type in ('chatglm', 'glm4'):
             for idx, module in enumerate(module_list):
                 getattr(module, attention).layer_number = idx
@@ -135,6 +154,34 @@ def get_model_key_mapping(cls, model_type, config) -> ModelKeys:
             'LLaMAPro only support models with o_proj and down_proj components.'
         return model_key_mapping
 
+    @classmethod
+    def search_correct_model_type(cls, module: nn.Module):
+        for arch_name, arch_type in MODEL_ARCH_MAPPING.items():
+            arch_type: ModelKeys
+            if getattr(arch_type, 'module_list') is None:
+                # Need to be a LLM arch
+                continue
+
+            matched = True
+            for f in fields(arch_type):
+                arch_str = getattr(arch_type, f.name)
+                if f.name == 'arch_name' or arch_str is None:
+                    continue
+
+                arch_str = arch_str.replace('{}', '0')
+                try:
+                    sub_module = module.get_submodule(arch_str)
+                    if sub_module is None:
+                        matched = False
+                except AttributeError:
+                    matched = False
+
+                if not matched:
+                    break
+
+            if matched:
+                return arch_name
+
     @staticmethod
     def _update_module_weight(config: LLaMAProConfig, module_list, new_module_idx):
         model_key_mapping = LLaMAPro.get_model_key_mapping(config.model_type, config)
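
For context, the new `search_correct_model_type` fallback works by probing each registered architecture's key paths against the live module tree with `nn.Module.get_submodule`, which raises `AttributeError` for any path that does not resolve. Below is a minimal standalone sketch of that probing idea; the toy key table and toy model are hypothetical stand-ins, not swift's real `MODEL_ARCH_MAPPING` or `ModelKeys`.

# Sketch only: illustrates submodule-path probing for architecture detection.
# TOY_ARCH_KEYS and ToyLlama are invented for this example.
from torch import nn

TOY_ARCH_KEYS = {
    # arch name -> template paths that must all resolve inside the module tree
    'toy-llama': ['model.layers.{}', 'model.layers.{}.self_attn.o_proj'],
    'toy-gpt2': ['transformer.h.{}', 'transformer.h.{}.attn.c_proj'],
}


class ToyLlama(nn.Module):

    def __init__(self):
        super().__init__()
        layer = nn.ModuleDict({'self_attn': nn.ModuleDict({'o_proj': nn.Linear(8, 8)})})
        self.model = nn.ModuleDict({'layers': nn.ModuleList([layer])})


def search_toy_arch(module: nn.Module):
    for arch_name, paths in TOY_ARCH_KEYS.items():
        try:
            for path in paths:
                # '{}' is the layer-index placeholder; probe layer 0 only.
                module.get_submodule(path.format(0))
        except AttributeError:
            continue  # a path is missing, so this is not the architecture
        return arch_name
    return None


print(search_toy_arch(ToyLlama()))  # -> 'toy-llama'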