@@ -622,6 +622,8 @@ def get_model_tokenizer_baichuan2_int4(model_dir: str,
     if device_map != 'auto':
         accelerate.infer_auto_device_map = _old_infer_auto_device_map
     if model is not None:
+        model.config.quantization_config = BitsAndBytesConfig(
+            **model.config.quantization_config)
         model.train()
         model._is_quantized_training_enabled = True
         model.is_loaded_in_4bit = True
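The two added lines handle the case where the Baichuan2-int4 checkpoint leaves `model.config.quantization_config` as a plain dict: re-wrapping it in `BitsAndBytesConfig` before training is enabled restores the config-object interface that later quantized-training code relies on. A minimal, self-contained sketch of that conversion (the dict keys below are illustrative assumptions, not read from the actual checkpoint):

from transformers import BitsAndBytesConfig

# Illustrative dict, the shape a checkpoint's config.json might use
# (keys are assumptions for the example, not copied from Baichuan2-int4).
quant_dict = {'load_in_4bit': True, 'bnb_4bit_quant_type': 'nf4'}

# Wrapping the dict restores attribute access and helpers such as to_dict(),
# which downstream training code expects from a quantization config object.
quant_config = BitsAndBytesConfig(**quant_dict)
print(quant_config.load_in_4bit)                      # True
print(quant_config.to_dict()['bnb_4bit_quant_type'])  # 'nf4'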
@@ -1186,52 +1188,15 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
     function_kwargs={'bits': 8},
     support_flash_attn=True,
     support_vllm=True)
-def get_model_tokenizer_with_flash_attn_intx(model_dir: str,
-                                             torch_dtype: Dtype,
-                                             model_kwargs: Dict[str, Any],
-                                             load_model: bool = True,
-                                             model_config=None,
-                                             **kwargs):
-    if model_config is None:
-        model_config = AutoConfig.from_pretrained(
-            model_dir, trust_remote_code=True)
-    use_flash_attn = kwargs.pop('use_flash_attn', False)
-    if version.parse(transformers.__version__) >= version.parse('4.36'):
-        if use_flash_attn:
-            model_config._attn_implementation = 'flash_attention_2'
-    else:
-        model_config._flash_attn_2_enabled = use_flash_attn
-
-    logger.info('use gptq, ignore bnb arguments')
-    bits = kwargs.pop('bits')
-    if version.parse(transformers.__version__) >= version.parse('4.35'):
-        model_kwargs['quantization_config'] = GPTQConfig(
-            bits=bits, use_exllama=False)
-    else:
-        model_kwargs['quantization_config'] = GPTQConfig(
-            bits=bits, disable_exllama=True)
-
-    # fix quantlinear bug
-    from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear
-    __old_forward = QuantLinear.forward
-
-    def _new_forward(self, x):
-        if not self.training or not self.autogptq_cuda_available:
-            return self.__old_forward(x)
-        # fix sft no grad
-        self.autogptq_cuda_available = False
-        res = self.__old_forward(x)
-        self.autogptq_cuda_available = True
-        return res
+def get_model_tokenizer_with_qwen1half_intx(model_dir: str,
+                                            torch_dtype: Dtype,
+                                            model_kwargs: Dict[str, Any],
+                                            load_model: bool = True,
+                                            **kwargs):
 
-    if not hasattr(QuantLinear, '__old_forward'):  # avoid double patching
-        QuantLinear.__old_forward = __old_forward
-        QuantLinear.forward = _new_forward
-    get_qwen_function = kwargs.pop('get_qwen_function',
-                                   get_model_tokenizer_with_flash_attn)
-    model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
+    kwargs['get_qwen_function'] = get_model_tokenizer_with_flash_attn
+    return get_model_tokenizer_qwen_intx(model_dir, torch_dtype, model_kwargs,
                                          load_model, **kwargs)
-    return model, tokenizer
 
 
 @register_model(
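This hunk removes the duplicated GPTQ setup (quantization config, exllama flags, QuantLinear forward patch) from the Qwen1.5 path: `get_model_tokenizer_with_qwen1half_intx` now only pins the base loader via `kwargs['get_qwen_function']` and delegates to the shared `get_model_tokenizer_qwen_intx`. A minimal sketch of that delegation pattern, with hypothetical stand-in loaders (only the keyword handoff mirrors the diff; the bodies are not the real implementations):

from typing import Any, Dict


def load_with_flash_attn(model_dir: str, **kwargs) -> Dict[str, Any]:
    # Stand-in for get_model_tokenizer_with_flash_attn.
    return {'model_dir': model_dir, 'loader': 'flash_attn', **kwargs}


def load_qwen_intx(model_dir: str, **kwargs) -> Dict[str, Any]:
    # Stand-in for get_model_tokenizer_qwen_intx: the shared GPTQ logic pops
    # the base loader from kwargs, defaulting to the flash-attention loader.
    get_base = kwargs.pop('get_qwen_function', load_with_flash_attn)
    return get_base(model_dir, **kwargs)


def load_qwen1half_intx(model_dir: str, **kwargs) -> Dict[str, Any]:
    # Stand-in for get_model_tokenizer_with_qwen1half_intx: no duplicated
    # GPTQ setup, just pin the base loader and delegate.
    kwargs['get_qwen_function'] = load_with_flash_attn
    return load_qwen_intx(model_dir, **kwargs)


print(load_qwen1half_intx('Qwen/Qwen1.5-7B-Chat-GPTQ-Int8'))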