import torch.nn as nn
from tqdm import tqdm

-from swift.llm import (ExportArguments, HfConfigFactory, MaxLengthError, ProcessorMixin, deep_getattr, get_model_arch,
-                       load_dataset, prepare_model_template, save_checkpoint, to_device)
+from swift.llm import (ExportArguments, HfConfigFactory, MaxLengthError, ProcessorMixin, deep_getattr, load_dataset,
+                       prepare_model_template, save_checkpoint, to_device)
from swift.utils import get_logger, get_model_parameter_info

logger = get_logger()
@@ -160,7 +160,7 @@ def awq_model_quantize(self) -> None:
            self.tokenizer, quant_config=quant_config, n_parallel_calib_samples=args.quant_batch_size)
        quantizer.get_calib_dataset = _origin_get_calib_dataset  # recover
        if self.model.quant_config.modules_to_not_convert:
-            model_arch = get_model_arch(args.model_meta.model_arch)
+            model_arch = args.model_meta.model_arch
            lm_head_key = getattr(model_arch, 'lm_head', None) or 'lm_head'
            if lm_head_key not in self.model.quant_config.modules_to_not_convert:
                self.model.quant_config.modules_to_not_convert.append(lm_head_key)
@@ -180,7 +180,7 @@ def _patch_gptq(self):

    @staticmethod
    def get_block_name_to_quantize(model: nn.Module) -> Optional[str]:
-        model_arch = get_model_arch(model.model_meta.model_arch)
+        model_arch = model.model_meta.model_arch
        prefix = ''
        if hasattr(model_arch, 'language_model'):
            assert len(model_arch.language_model) == 1, f'mllm_arch.language_model: {model_arch.language_model}'