
Commit 8fc1c41

Fix baichuan2 int4 bug (#400)
1 parent fa85682 commit 8fc1c41

File tree

3 files changed (+30, -47 lines)


docs/source/LLM/Agent微调最佳实践.md

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 - [环境安装](#环境安装)
 - [数据准备](#数据准备)
 - [微调](#微调)
-- [推理](#微调后推理)
+- [推理](#推理)
 - [总结](#总结)
 
 ## 环境安装
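
Note: the TOC entry labelled 推理 ("Inference") previously pointed at the anchor #微调后推理 ("inference after fine-tuning"); the heading in this document is apparently just 推理, so the link target is updated to #推理 and the in-page link resolves again.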

swift/llm/utils/model.py

Lines changed: 9 additions & 44 deletions
@@ -622,6 +622,8 @@ def get_model_tokenizer_baichuan2_int4(model_dir: str,
     if device_map != 'auto':
         accelerate.infer_auto_device_map = _old_infer_auto_device_map
     if model is not None:
+        model.config.quantization_config = BitsAndBytesConfig(
+            **model.config.quantization_config)
         model.train()
         model._is_quantized_training_enabled = True
         model.is_loaded_in_4bit = True
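
This hunk is the actual baichuan2 int4 fix. The quantization_config that the Baichuan2 int4 checkpoint attaches to model.config arrives as a plain dict (the ** unpacking above only works on a mapping), and the loader now rebuilds it as a transformers BitsAndBytesConfig so that later code which expects a real config object, attribute access and all, keeps working. Below is a minimal sketch of that conversion outside swift; the dict keys are illustrative stand-ins, not necessarily the exact ones the checkpoint ships.

# Minimal sketch, assuming transformers (and torch) are installed; the keys in
# quant_dict are illustrative stand-ins for whatever the checkpoint provides.
from transformers import BitsAndBytesConfig

quant_dict = {
    'load_in_4bit': True,                   # 4-bit bitsandbytes quantization
    'bnb_4bit_quant_type': 'nf4',
    'bnb_4bit_use_double_quant': True,
    'bnb_4bit_compute_dtype': 'bfloat16',   # strings are converted to torch dtypes
}

# Same pattern as the fix: unpack the dict into a proper config object.
quant_config = BitsAndBytesConfig(**quant_dict)
print(quant_config.bnb_4bit_compute_dtype)            # torch.bfloat16; attribute access works
print(quant_config.to_dict()['bnb_4bit_quant_type'])  # 'nf4'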
@@ -1186,52 +1188,15 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
     function_kwargs={'bits': 8},
     support_flash_attn=True,
     support_vllm=True)
-def get_model_tokenizer_with_flash_attn_intx(model_dir: str,
-                                             torch_dtype: Dtype,
-                                             model_kwargs: Dict[str, Any],
-                                             load_model: bool = True,
-                                             model_config=None,
-                                             **kwargs):
-    if model_config is None:
-        model_config = AutoConfig.from_pretrained(
-            model_dir, trust_remote_code=True)
-    use_flash_attn = kwargs.pop('use_flash_attn', False)
-    if version.parse(transformers.__version__) >= version.parse('4.36'):
-        if use_flash_attn:
-            model_config._attn_implementation = 'flash_attention_2'
-    else:
-        model_config._flash_attn_2_enabled = use_flash_attn
-
-    logger.info('use gptq, ignore bnb arguments')
-    bits = kwargs.pop('bits')
-    if version.parse(transformers.__version__) >= version.parse('4.35'):
-        model_kwargs['quantization_config'] = GPTQConfig(
-            bits=bits, use_exllama=False)
-    else:
-        model_kwargs['quantization_config'] = GPTQConfig(
-            bits=bits, disable_exllama=True)
-
-    # fix quantlinear bug
-    from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear
-    __old_forward = QuantLinear.forward
-
-    def _new_forward(self, x):
-        if not self.training or not self.autogptq_cuda_available:
-            return self.__old_forward(x)
-        # fix sft no grad
-        self.autogptq_cuda_available = False
-        res = self.__old_forward(x)
-        self.autogptq_cuda_available = True
-        return res
+def get_model_tokenizer_with_qwen1half_intx(model_dir: str,
+                                            torch_dtype: Dtype,
+                                            model_kwargs: Dict[str, Any],
+                                            load_model: bool = True,
+                                            **kwargs):
 
-    if not hasattr(QuantLinear, '__old_forward'):  # avoid double patching
-        QuantLinear.__old_forward = __old_forward
-        QuantLinear.forward = _new_forward
-    get_qwen_function = kwargs.pop('get_qwen_function',
-                                   get_model_tokenizer_with_flash_attn)
-    model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
+    kwargs['get_qwen_function'] = get_model_tokenizer_with_flash_attn
+    return get_model_tokenizer_qwen_intx(model_dir, torch_dtype, model_kwargs,
                                          load_model, **kwargs)
-    return model, tokenizer
 
 
 @register_model(
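
The rest of this file's change is a tidy-up around the same loader family rather than a second fix: the old get_model_tokenizer_with_flash_attn_intx carried its own GPTQ setup plus a monkey-patch of auto_gptq's QuantLinear.forward (temporarily clearing autogptq_cuda_available while training so the quantized linear takes the autograd-friendly path, guarded with hasattr to avoid patching twice). Its replacement, get_model_tokenizer_with_qwen1half_intx, simply forwards to get_model_tokenizer_qwen_intx with get_qwen_function set, and that shared helper presumably hosts the GPTQ/QuantLinear handling now. For reference, a standalone sketch of the transformers version gate the deleted code used when building its GPTQ config:

# Standalone sketch of the removed version gate (transformers + packaging only).
# At and above 4.35 the deleted code passed `use_exllama`; below that it passed
# the inverse flag `disable_exllama`. Both keep the exllama kernels off for training.
import transformers
from packaging import version
from transformers import GPTQConfig

bits = 4  # illustrative; swift feeds bits in via register_model's function_kwargs

if version.parse(transformers.__version__) >= version.parse('4.35'):
    quantization_config = GPTQConfig(bits=bits, use_exllama=False)
else:
    quantization_config = GPTQConfig(bits=bits, disable_exllama=True)

print(quantization_config.bits)  # 4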

tests/llm/test_run.py

Lines changed: 20 additions & 2 deletions
@@ -115,7 +115,7 @@ def test_loss_matching(self):
         infer_main([
             '--ckpt_dir', best_model_checkpoint, '--show_dataset_sample',
             str(show_dataset_sample), '--max_new_tokens', '100',
-            '--use_flash_attn', 'true', '--verbose',
+            '--use_flash_attn', 'false', '--verbose',
             str(not bool_var), '--merge_lora_and_save',
             str(bool_var), '--load_dataset_config',
             str(load_dataset_config)
@@ -220,7 +220,7 @@ def test_self_cognition(self):
             self_cognition_sample=100,
             model_name=['小黄', 'Xiao Huang'],
             model_author=['魔搭', 'ModelScope'],
-            use_flash_attn=False)
+            use_flash_attn=True)
         torch.cuda.empty_cache()
         output = sft_main(sft_args)
         last_model_checkpoint = output['last_model_checkpoint']
@@ -350,6 +350,24 @@ def test_pai_compat(self):
         infer_main([infer_json])
         os.environ.pop('PAI_TRAINING_JOB_ID')
 
+    def test_baichuan2_chat_int4(self):
+        if not __name__ == '__main__':
+            # ignore citest error in github
+            return
+        from swift.llm import sft_main, infer_main, SftArguments, InferArguments, ModelType, DatasetName
+        output = sft_main(
+            SftArguments(
+                model_type=ModelType.baichuan2_7b_chat_int4,
+                dataset=['alpaca-zh'],
+                lora_target_modules=['DEFAULT', 'EMBEDDING'],
+                train_dataset_sample=20))
+        best_model_checkpoint = output['best_model_checkpoint']
+        infer_main(
+            InferArguments(
+                ckpt_dir=best_model_checkpoint,
+                load_dataset_config=True,
+                val_dataset_sample=1))
+
 
 def data_collate_fn(batch: List[Dict[str, Any]],
                     tokenizer) -> Dict[str, Tensor]:
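
The new test_baichuan2_chat_int4 exercises the fix end to end: a short LoRA SFT run on baichuan2-7b-chat-int4 (20 samples of alpaca-zh, with the embedding layer added to lora_target_modules) followed by inference from the best checkpoint. The "if not __name__ == '__main__'" guard keeps it out of the GitHub CI job, as the comment notes, so it only runs when the file is executed directly; the two use_flash_attn flips in the older tests simply reshuffle which run covers FlashAttention. A quicker, hypothetical smoke check for the loader fix alone (assuming swift.llm.get_model_tokenizer's usual signature; this is not part of the commit):

# Hypothetical smoke check, not taken from this commit: after the fix, the int4
# model's config should carry a BitsAndBytesConfig object rather than a dict.
from transformers import BitsAndBytesConfig

from swift.llm import ModelType, get_model_tokenizer

model, tokenizer = get_model_tokenizer(
    ModelType.baichuan2_7b_chat_int4, model_kwargs={'device_map': 'auto'})
assert isinstance(model.config.quantization_config, BitsAndBytesConfig)
print(type(model.config.quantization_config).__name__)  # BitsAndBytesConfig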
