
Commit 408d500

Feat/qwen1.5 (#385)
1 parent 31b2010 commit 408d500

File tree

4 files changed: 287 additions, 26 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
 
 
 ## 🎉 News
+- 🔥2024.02.05: Support the qwen1.5 series, e.g. [qwen1.5-0.5b](https://www.modelscope.cn/models/qwen/Qwen1.5-0.5B/summary), [qwen1.5-7b](https://www.modelscope.cn/models/qwen/Qwen1.5-7B/summary), [qwen1.5-14b](https://www.modelscope.cn/models/qwen/Qwen1.5-14B/summary). For all supported qwen1.5 models, see the [Model List](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md).
 - 2024.02.01: Support the openbmb-minicpm series: [openbmb-minicpm-2b-sft-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/openbmb_minicpm_2b_sft_chat), openbmb-minicpm-2b-chat.
 - 🔥2024.02.01: Support dataset mixture to reduce **catastrophic forgetting**. Use `--train_dataset_mix_ratio 2.0` to train with it! We also provide a general-knowledge dataset, [ms-bench](https://www.modelscope.cn/datasets/iic/ms_bench/summary).
 - 🔥2024.02.01: Support Agent training! The training algorithm comes from this [paper](https://arxiv.org/pdf/2309.00986.pdf). We also introduce the [ms-agent](https://www.modelscope.cn/datasets/iic/ms_agent/summary) dataset. Use [this script](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/qwen_7b_chat/lora/sft.sh) to start agent training!
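The news items above pair the new qwen1.5 model types with the `--train_dataset_mix_ratio` flag. A minimal sketch of combining them through swift's Python entry point, assuming `swift.llm` exposes `SftArguments`/`sft_main` and that the dataclass fields mirror the CLI flags (names may differ across versions):

```python
# Hedged sketch: LoRA fine-tuning of a qwen1.5 chat model with dataset mixing.
# `SftArguments`/`sft_main` and the field names below are assumptions based on
# the CLI flags referenced in the news items; verify against your version.
from swift.llm import SftArguments, sft_main

args = SftArguments(
    model_type='qwen1half-7b-chat',  # one of the model types added here
    sft_type='lora',
    dataset=['ms-agent'],            # agent dataset from the news items
    train_dataset_mix_ratio=2.0)     # mix in general data (e.g. ms-bench)
                                     # to reduce catastrophic forgetting
result = sft_main(args)
```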

README_CN.md

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning) is a scalable
 Users can check the [official SWIFT documentation](docs/source/GetStarted/快速使用.md) for details.
 
 ## 🎉 News
+- 🔥2024.02.05: Support the qwen1.5 series: [qwen1.5-0.5b](https://www.modelscope.cn/models/qwen/Qwen1.5-0.5B/summary), [qwen1.5-7b](https://www.modelscope.cn/models/qwen/Qwen1.5-7B/summary), [qwen1.5-14b](https://www.modelscope.cn/models/qwen/Qwen1.5-14B/summary), etc. For all supported qwen1.5 models, see the [Model List](https://github.com/modelscope/swift/blob/main/docs/source/LLM/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.md).
 - 2024.02.01: Support the openbmb-minicpm series: [openbmb-minicpm-2b-sft-chat](https://github.com/modelscope/swift/tree/main/examples/pytorch/llm/scripts/openbmb_minicpm_2b_sft_chat), openbmb-minicpm-2b-chat.
 - 🔥2024.02.01: Support dataset mixture to reduce **catastrophic forgetting**. Use `--train_dataset_mix_ratio 2.0` to enable it during training! We also open-source a general-knowledge dataset, [ms-bench](https://www.modelscope.cn/datasets/iic/ms_bench/summary).
 - 🔥2024.02.01: Support Agent training! The training algorithm comes from this [paper](https://arxiv.org/pdf/2309.00986.pdf). We also add [ms-agent](https://www.modelscope.cn/datasets/iic/ms_agent/summary), a high-quality agent dataset. Use [this script](https://github.com/modelscope/swift/blob/main/examples/pytorch/llm/scripts/qwen_7b_chat/lora/sft.sh) to start agent training!

docs/source/LLM/支持的模型和数据集.md

Lines changed: 24 additions & 0 deletions
@@ -30,6 +30,30 @@
 |qwen-72b-chat|[qwen/Qwen-72B-Chat](https://modelscope.cn/models/qwen/Qwen-72B-Chat/summary)|c_attn|qwen|✔|✔||
 |qwen-72b-chat-int4|[qwen/Qwen-72B-Chat-Int4](https://modelscope.cn/models/qwen/Qwen-72B-Chat-Int4/summary)|c_attn|qwen|✔|✘|auto_gptq>=0.5|
 |qwen-72b-chat-int8|[qwen/Qwen-72B-Chat-Int8](https://modelscope.cn/models/qwen/Qwen-72B-Chat-Int8/summary)|c_attn|qwen|✔|✘|auto_gptq>=0.5|
+|qwen1half-0_5b|[qwen/Qwen1.5-0.5B](https://modelscope.cn/models/qwen/Qwen1.5-0.5B/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.37|
+|qwen1half-1_8b|[qwen/Qwen1.5-1.8B](https://modelscope.cn/models/qwen/Qwen1.5-1.8B/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.37|
+|qwen1half-4b|[qwen/Qwen1.5-4B](https://modelscope.cn/models/qwen/Qwen1.5-4B/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.37|
+|qwen1half-7b|[qwen/Qwen1.5-7B](https://modelscope.cn/models/qwen/Qwen1.5-7B/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.37|
+|qwen1half-14b|[qwen/Qwen1.5-14B](https://modelscope.cn/models/qwen/Qwen1.5-14B/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.37|
+|qwen1half-72b|[qwen/Qwen1.5-72B](https://modelscope.cn/models/qwen/Qwen1.5-72B/summary)|q_proj, k_proj, v_proj|default-generation|✔|✔|transformers>=4.37|
+|qwen1half-0_5b-chat|[qwen/Qwen1.5-0.5B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|transformers>=4.37|
+|qwen1half-1_8b-chat|[qwen/Qwen1.5-1.8B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|transformers>=4.37|
+|qwen1half-4b-chat|[qwen/Qwen1.5-4B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-4B-Chat/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|transformers>=4.37|
+|qwen1half-7b-chat|[qwen/Qwen1.5-7B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-7B-Chat/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|transformers>=4.37|
+|qwen1half-14b-chat|[qwen/Qwen1.5-14B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|transformers>=4.37|
+|qwen1half-72b-chat|[qwen/Qwen1.5-72B-Chat](https://modelscope.cn/models/qwen/Qwen1.5-72B-Chat/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|transformers>=4.37|
+|qwen1half-0_5b-chat-int8|[qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-0_5b-chat-int4|[qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4](https://modelscope.cn/models/qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-1_8b-chat-int8|[qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-1_8b-chat-int4|[qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4](https://modelscope.cn/models/qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-4b-chat-int8|[qwen/Qwen1.5-4B-Chat-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen1.5-4B-Chat-GPTQ-Int8/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-4b-chat-int4|[qwen/Qwen1.5-4B-Chat-GPTQ-Int4](https://modelscope.cn/models/qwen/Qwen1.5-4B-Chat-GPTQ-Int4/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-7b-chat-int8|[qwen/Qwen1.5-7B-Chat-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen1.5-7B-Chat-GPTQ-Int8/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-7b-chat-int4|[qwen/Qwen1.5-7B-Chat-GPTQ-Int4](https://modelscope.cn/models/qwen/Qwen1.5-7B-Chat-GPTQ-Int4/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-14b-chat-int8|[qwen/Qwen1.5-14B-Chat-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat-GPTQ-Int8/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-14b-chat-int4|[qwen/Qwen1.5-14B-Chat-GPTQ-Int4](https://modelscope.cn/models/qwen/Qwen1.5-14B-Chat-GPTQ-Int4/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-72b-chat-int8|[qwen/Qwen1.5-72B-Chat-GPTQ-Int8](https://modelscope.cn/models/qwen/Qwen1.5-72B-Chat-GPTQ-Int8/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
+|qwen1half-72b-chat-int4|[qwen/Qwen1.5-72B-Chat-GPTQ-Int4](https://modelscope.cn/models/qwen/Qwen1.5-72B-Chat-GPTQ-Int4/summary)|q_proj, k_proj, v_proj|chatml|✔|✔|auto_gptq>=0.5, transformers>=4.37|
 |qwen-vl|[qwen/Qwen-VL](https://modelscope.cn/models/qwen/Qwen-VL/summary)|c_attn|default-generation|✔|✘||
 |qwen-vl-chat|[qwen/Qwen-VL-Chat](https://modelscope.cn/models/qwen/Qwen-VL-Chat/summary)|c_attn|qwen|✔|✘||
 |qwen-vl-chat-int4|[qwen/Qwen-VL-Chat-Int4](https://modelscope.cn/models/qwen/Qwen-VL-Chat-Int4/summary)|c_attn|qwen|✔|✘|auto_gptq>=0.5|
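All qwen1.5 rows require transformers>=4.37, the release that introduces the Qwen2 architecture these checkpoints use. As a sanity check outside of swift, a minimal sketch of loading one listed checkpoint directly from ModelScope, assuming the `modelscope` package's transformers-compatible `Auto*` wrappers:

```python
# Hedged sketch: load a listed qwen1.5 checkpoint straight from ModelScope.
# Requires transformers>=4.37 (Qwen2 architecture); the modelscope package
# re-exports transformers-style Auto classes that download from modelscope.cn.
from modelscope import AutoModelForCausalLM, AutoTokenizer

model_id = 'qwen/Qwen1.5-0.5B-Chat'
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype='auto')

# The chat rows use the chatml template, which the tokenizer ships built in.
messages = [{'role': 'user', 'content': 'Who are you?'}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors='pt')
print(tokenizer.decode(model.generate(input_ids, max_new_tokens=64)[0]))
```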

swift/llm/utils/model.py

Lines changed: 261 additions & 26 deletions
@@ -51,13 +51,32 @@ class ModelType:
     qwen_72b_chat = 'qwen-72b-chat'
     qwen_72b_chat_int4 = 'qwen-72b-chat-int4'
     qwen_72b_chat_int8 = 'qwen-72b-chat-int8'
-    # qwen2
-    qwen2_beta_0_5b = 'qwen2-beta-0_5b'
-    qwen2_beta_1_8b = 'qwen2-beta-1_8b'
-    qwen2_beta_4b = 'qwen2-beta-4b'
-    qwen2_beta_7b = 'qwen2-beta-7b'
-    qwen2_beta_14b = 'qwen2-beta-14b'
-    qwen2_beta_72b = 'qwen2-beta-72b'
+    # qwen1.5
+    qwen1half_0_5b = 'qwen1half-0_5b'
+    qwen1half_1_8b = 'qwen1half-1_8b'
+    qwen1half_4b = 'qwen1half-4b'
+    qwen1half_7b = 'qwen1half-7b'
+    qwen1half_14b = 'qwen1half-14b'
+    qwen1half_72b = 'qwen1half-72b'
+    qwen1half_0_5b_chat = 'qwen1half-0_5b-chat'
+    qwen1half_1_8b_chat = 'qwen1half-1_8b-chat'
+    qwen1half_4b_chat = 'qwen1half-4b-chat'
+    qwen1half_7b_chat = 'qwen1half-7b-chat'
+    qwen1half_14b_chat = 'qwen1half-14b-chat'
+    qwen1half_72b_chat = 'qwen1half-72b-chat'
+    # qwen1.5 autogptq
+    qwen1half_0_5b_chat_int8 = 'qwen1half-0_5b-chat-int8'
+    qwen1half_0_5b_chat_int4 = 'qwen1half-0_5b-chat-int4'
+    qwen1half_1_8b_chat_int8 = 'qwen1half-1_8b-chat-int8'
+    qwen1half_1_8b_chat_int4 = 'qwen1half-1_8b-chat-int4'
+    qwen1half_4b_chat_int8 = 'qwen1half-4b-chat-int8'
+    qwen1half_4b_chat_int4 = 'qwen1half-4b-chat-int4'
+    qwen1half_7b_chat_int8 = 'qwen1half-7b-chat-int8'
+    qwen1half_7b_chat_int4 = 'qwen1half-7b-chat-int4'
+    qwen1half_14b_chat_int8 = 'qwen1half-14b-chat-int8'
+    qwen1half_14b_chat_int4 = 'qwen1half-14b-chat-int4'
+    qwen1half_72b_chat_int8 = 'qwen1half-72b-chat-int8'
+    qwen1half_72b_chat_int4 = 'qwen1half-72b-chat-int4'
     # qwen-vl
     qwen_vl = 'qwen-vl'
     qwen_vl_chat = 'qwen-vl-chat'
@@ -219,7 +238,7 @@ class LoRATM(NamedTuple):
     chatglm = ['query_key_value']
     llama2 = ['q_proj', 'k_proj', 'v_proj']
     qwen = ['c_attn']
-    qwen2 = llama2
+    qwen1half = llama2
     polylm = ['c_attn']
     bloom = ['query_key_value']
     cogagent = [
@@ -694,53 +713,101 @@ def cross_entropy_forward(self, inputs: Tensor,
 
 
 @register_model(
-    ModelType.qwen2_beta_0_5b,
-    'qwen/Qwen2-beta-0_5B',
-    LoRATM.qwen2,
+    ModelType.qwen1half_0_5b,
+    'qwen/Qwen1.5-0.5B',
+    LoRATM.qwen1half,
     TemplateType.default_generation,
     support_flash_attn=True,
     support_vllm=True,
     requires=['transformers>=4.37'])
 @register_model(
-    ModelType.qwen2_beta_1_8b,
-    'qwen/Qwen2-beta-1_8B',
-    LoRATM.qwen2,
+    ModelType.qwen1half_0_5b_chat,
+    'qwen/Qwen1.5-0.5B-Chat',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    support_flash_attn=True,
+    support_vllm=True,
+    requires=['transformers>=4.37'])
+@register_model(
+    ModelType.qwen1half_1_8b,
+    'qwen/Qwen1.5-1.8B',
+    LoRATM.qwen1half,
     TemplateType.default_generation,
     support_flash_attn=True,
     support_vllm=True,
     requires=['transformers>=4.37'])
 @register_model(
-    ModelType.qwen2_beta_4b,
-    'qwen/Qwen2-beta-4B',
-    LoRATM.qwen2,
+    ModelType.qwen1half_1_8b_chat,
+    'qwen/Qwen1.5-1.8B-Chat',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    support_flash_attn=True,
+    support_vllm=True,
+    requires=['transformers>=4.37'])
+@register_model(
+    ModelType.qwen1half_4b,
+    'qwen/Qwen1.5-4B',
+    LoRATM.qwen1half,
     TemplateType.default_generation,
     support_flash_attn=True,
     support_vllm=True,
     requires=['transformers>=4.37'])
 @register_model(
-    ModelType.qwen2_beta_7b,
-    'qwen/Qwen2-beta-7B',
-    LoRATM.qwen2,
+    ModelType.qwen1half_4b_chat,
+    'qwen/Qwen1.5-4B-Chat',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    support_flash_attn=True,
+    support_vllm=True,
+    requires=['transformers>=4.37'])
+@register_model(
+    ModelType.qwen1half_7b,
+    'qwen/Qwen1.5-7B',
+    LoRATM.qwen1half,
     TemplateType.default_generation,
     support_flash_attn=True,
     support_vllm=True,
     requires=['transformers>=4.37'])
 @register_model(
-    ModelType.qwen2_beta_14b,
-    'qwen/Qwen2-beta-14B',
-    LoRATM.qwen2,
+    ModelType.qwen1half_7b_chat,
+    'qwen/Qwen1.5-7B-Chat',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    support_flash_attn=True,
+    support_vllm=True,
+    requires=['transformers>=4.37'])
+@register_model(
+    ModelType.qwen1half_14b,
+    'qwen/Qwen1.5-14B',
+    LoRATM.qwen1half,
     TemplateType.default_generation,
     support_flash_attn=True,
     support_vllm=True,
     requires=['transformers>=4.37'])
 @register_model(
-    ModelType.qwen2_beta_72b,
-    'qwen/Qwen2-beta-72B',
-    LoRATM.qwen2,
+    ModelType.qwen1half_14b_chat,
+    'qwen/Qwen1.5-14B-Chat',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    support_flash_attn=True,
+    support_vllm=True,
+    requires=['transformers>=4.37'])
+@register_model(
+    ModelType.qwen1half_72b,
+    'qwen/Qwen1.5-72B',
+    LoRATM.qwen1half,
     TemplateType.default_generation,
     support_flash_attn=True,
     support_vllm=True,
     requires=['transformers>=4.37'])
+@register_model(
+    ModelType.qwen1half_72b_chat,
+    'qwen/Qwen1.5-72B-Chat',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    support_flash_attn=True,
+    support_vllm=True,
+    requires=['transformers>=4.37'])
 @register_model(
     ModelType.deepseek_coder_1_3b,
     'deepseek-ai/deepseek-coder-1.3b-base',
@@ -997,6 +1064,174 @@ def get_model_tokenizer_with_flash_attn(model_dir: str,
         **kwargs)
 
 
+@register_model(
+    ModelType.qwen1half_0_5b_chat_int4,
+    'qwen/Qwen1.5-0.5B-Chat-GPTQ-Int4',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 4},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_0_5b_chat_int8,
+    'qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 8},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_1_8b_chat_int4,
+    'qwen/Qwen1.5-1.8B-Chat-GPTQ-Int4',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 4},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_1_8b_chat_int8,
+    'qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 8},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_4b_chat_int4,
+    'qwen/Qwen1.5-4B-Chat-GPTQ-Int4',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 4},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_4b_chat_int8,
+    'qwen/Qwen1.5-4B-Chat-GPTQ-Int8',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 8},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_7b_chat_int4,
+    'qwen/Qwen1.5-7B-Chat-GPTQ-Int4',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 4},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_7b_chat_int8,
+    'qwen/Qwen1.5-7B-Chat-GPTQ-Int8',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 8},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_14b_chat_int4,
+    'qwen/Qwen1.5-14B-Chat-GPTQ-Int4',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 4},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_14b_chat_int8,
+    'qwen/Qwen1.5-14B-Chat-GPTQ-Int8',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 8},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_72b_chat_int4,
+    'qwen/Qwen1.5-72B-Chat-GPTQ-Int4',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 4},
+    support_flash_attn=True,
+    support_vllm=True)
+@register_model(
+    ModelType.qwen1half_72b_chat_int8,
+    'qwen/Qwen1.5-72B-Chat-GPTQ-Int8',
+    LoRATM.qwen1half,
+    TemplateType.chatml,
+    requires=['auto_gptq>=0.5', 'transformers>=4.37'],
+    torch_dtype=torch.float16,
+    function_kwargs={'bits': 8},
+    support_flash_attn=True,
+    support_vllm=True)
+def get_model_tokenizer_with_flash_attn_intx(model_dir: str,
+                                             torch_dtype: Dtype,
+                                             model_kwargs: Dict[str, Any],
+                                             load_model: bool = True,
+                                             model_config=None,
+                                             **kwargs):
+    if model_config is None:
+        model_config = AutoConfig.from_pretrained(
+            model_dir, trust_remote_code=True)
+    use_flash_attn = kwargs.pop('use_flash_attn', False)
+    if version.parse(transformers.__version__) >= version.parse('4.36'):
+        if use_flash_attn:
+            model_config._attn_implementation = 'flash_attention_2'
+    else:
+        model_config._flash_attn_2_enabled = use_flash_attn
+
+    logger.info('use gptq, ignore bnb arguments')
+    bits = kwargs.pop('bits')
+    if version.parse(transformers.__version__) >= version.parse('4.35'):
+        model_kwargs['quantization_config'] = GPTQConfig(
+            bits=bits, use_exllama=False)
+    else:
+        model_kwargs['quantization_config'] = GPTQConfig(
+            bits=bits, disable_exllama=True)
+
+    # fix quantlinear bug
+    from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear
+    __old_forward = QuantLinear.forward
+
+    def _new_forward(self, x):
+        if not self.training or not self.autogptq_cuda_available:
+            return self.__old_forward(x)
+        # fix sft no grad
+        self.autogptq_cuda_available = False
+        res = self.__old_forward(x)
+        self.autogptq_cuda_available = True
+        return res
+
+    if not hasattr(QuantLinear, '__old_forward'):  # avoid double patching
+        QuantLinear.__old_forward = __old_forward
+        QuantLinear.forward = _new_forward
+    get_qwen_function = kwargs.pop('get_qwen_function',
+                                   get_model_tokenizer_with_flash_attn)
+    model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
+                                         load_model, **kwargs)
+    return model, tokenizer
+
+
 @register_model(
     ModelType.internlm2_math_7b,
     'Shanghai_AI_Laboratory/internlm2-math-base-7b',
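The twelve GPTQ registrations above route through the new `get_model_tokenizer_with_flash_attn_intx`: it builds a `GPTQConfig` from `function_kwargs={'bits': ...}` with the exllama kernels disabled, patches auto_gptq's `QuantLinear.forward` to bypass the fused CUDA kernel while `self.training` is true (otherwise the output carries no gradient and LoRA SFT breaks), and then delegates to the regular flash-attn loader. A hedged usage sketch, assuming `swift.llm.get_model_tokenizer` resolves registered model types to this loader:

```python
# Hedged sketch: load one of the newly registered GPTQ checkpoints through
# swift's registry. The registration pins torch.float16 and passes
# function_kwargs={'bits': 4} into get_model_tokenizer_with_flash_attn_intx.
import torch
from swift.llm import get_model_tokenizer

model, tokenizer = get_model_tokenizer(
    'qwen1half-7b-chat-int4',            # ModelType added in this commit
    torch_dtype=torch.float16,
    model_kwargs={'device_map': 'auto'})
```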
