
Commit b055be6

Support AWQ & GroupWiseQuant for LLMs (#7688)
* Support AWQ & GroupWiseQuant for LLMs
* add docs
* add docs
* update
1 parent a55039c commit b055be6

File tree: 5 files changed, +138 −28 lines changed

llm/argument.py

Lines changed: 13 additions & 0 deletions
```diff
@@ -156,13 +156,20 @@ class QuantArgument:
     do_ptq: bool = field(default=False, metadata={"help": "Whether to use PTQ"})
     ptq_step: int = field(default=32, metadata={"help": "Step for PTQ"})
 
+    weight_quant_method: str = field(
+        default="abs_max_channel_wise",
+        metadata={"help": "Weight quantization method, chosen from ['abs_max_channel_wise', 'groupwise']"},
+    )
+
+
     # Pre-quant method Shift related parameters
     shift: bool = field(default=False, metadata={"help": "Whether to use Shift"})
     shift_all_linears: bool = field(default=False, metadata={"help": "Whether to shift all linears"})
     shift_sampler: str = field(
         default="ema", metadata={"help": "The name of shift sampler, chosen from ['ema', 'none']"}
     )
     shift_step: int = field(default=32, metadata={"help": "Sample steps when shift"})
 
+    # Pre-quant method Smooth related parameters
     smooth: bool = field(default=False, metadata={"help": "Whether to use Smooth"})
     smooth_all_linears: bool = field(default=False, metadata={"help": "Whether to smooth all linears"})
     smooth_sampler: str = field(
@@ -179,6 +186,12 @@ class QuantArgument:
     do_gptq: bool = field(default=False, metadata={"help": "Whether to use GPTQ"})
     gptq_step: int = field(default=8, metadata={"help": "Step for GPTQ"})
 
+    # AWQ related parameters, default for WINT4
+    do_awq: bool = field(default=False, metadata={"help": "Whether to use AWQ Search"})
+    auto_clip: bool = field(default=False, metadata={"help": "Whether to use AutoClip from AWQ"})
+    awq_step: int = field(default=8, metadata={"help": "Step for AWQ Search"})
+    autoclip_step: int = field(default=8, metadata={"help": "Step for AutoClip"})
+
 
 @dataclass
 class GenerateArgument:
```
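
For context, these dataclass fields are populated through PaddleNLP's argument parser. A minimal sketch of how a JSON config reaches the new fields, assuming `PdArgumentParser` mirrors `HfArgumentParser` as it does elsewhere in this repo; the trimmed-down dataclass below is illustrative, not the full QuantArgument:

```python
# Illustrative sketch only: a trimmed QuantArgument and how a JSON config
# such as llm/llama/awq_argument.json would populate it.
from dataclasses import dataclass, field

from paddlenlp.trainer import PdArgumentParser


@dataclass
class QuantArgument:
    do_awq: bool = field(default=False, metadata={"help": "Whether to use AWQ Search"})
    auto_clip: bool = field(default=False, metadata={"help": "Whether to use AutoClip from AWQ"})
    autoclip_step: int = field(default=8, metadata={"help": "Step for AutoClip"})
    weight_quant_method: str = field(
        default="abs_max_channel_wise",
        metadata={"help": "chosen from ['abs_max_channel_wise', 'groupwise']"},
    )


parser = PdArgumentParser(QuantArgument)
# parse_json_file returns one parsed dataclass per argument class.
(quant_args,) = parser.parse_json_file("./llama/awq_argument.json")
print(quant_args.do_awq, quant_args.weight_quant_method)
```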

llm/docs/quantization.md

Lines changed: 16 additions & 2 deletions
````diff
@@ -3,12 +3,14 @@
 ## 1. Algorithm Overview
 
 Large-model quantization converts 16- or 32-bit floating-point model weights or activations to 4- or 8-bit integers, which effectively reduces model storage and compute requirements while accelerating inference. The toolchain's quantization algorithms include:
-- **PTQ**. The PaddleSlim team's self-developed adaptive Shift-SmoothQuant quantization algorithm builds on [SmoothQuant](https://arxiv.org/abs/2211.10438) and [Outlier Suppression+](https://arxiv.org/abs/2304.09145),
+- **PTQ**. The PaddleSlim team's self-developed adaptive PiecewiseSearchSmooth (PSS) quantization algorithm builds on [SmoothQuant](https://arxiv.org/abs/2211.10438) and [Outlier Suppression+](https://arxiv.org/abs/2304.09145),
 adding the PieceWiseSearch parameter-search algorithm and extending it to **all linear layers**, adjusting the weight and activation distributions to reduce the loss of the subsequent A8W8 PTQ quantization.
 
 
 - **GPTQ**: [GPTQ](https://arxiv.org/abs/2210.17323) is a mainstream weight quantization algorithm that quantizes large-model weights to 4-bit integers with virtually no loss, improving inference speed.
 
+- **AWQ**: [AWQ](https://arxiv.org/abs/2306.00978) is a mainstream weight quantization algorithm that quantizes large-model weights to 4-bit integers with virtually no loss, improving inference speed.
+
 <div align="center">
 <img width="800" alt="llm" src="https://github.com/PaddlePaddle/PaddleNLP/assets/63761690/fe8f941b-4b35-48ca-814f-96533d7e24ce">
 </div>
@@ -65,12 +67,19 @@ python finetune_generation.py ./llama/ptq_argument.json
 python finetune_generation.py ./llama/gptq_argument.json
 ```
 
-### 2.5 Quantization Parameters
+### 2.5 AWQ Quantization
+
+```
+python finetune_generation.py ./llama/awq_argument.json
+```
+
+### 2.6 Quantization Parameters
 
 <summary>&emsp; Quantization parameters (QuantArgument)</summary><div>
 
 - `quant_type`: PTQ/QAT quantization type, defaults to A8W8. Supports A8W8, WINT4, WINT8: A8W8 quantizes activations (inputs) to INT8 and model weights to INT8; WINT4 quantizes only the model weights to INT4, with WeightOnly used for subsequent inference; WINT8 quantizes only the model weights to INT8, with WeightOnly used for subsequent inference.
 - `do_ptq`: Whether to run PTQ quantization, defaults to False.
+- `weight_quant_method`: Weight quantization method, currently either groupwise or abs_max_channel_wise.
 - `ptq_step`: Number of PTQ steps, i.e. model forward passes, defaults to 32.
 - `shift`: Whether to apply the [Shift strategy](https://arxiv.org/abs/2304.09145) before PTQ quantization, defaults to False. Using Shift requires `do_ptq` to be True.
 - `shift_all_linear`: Whether to apply Shift to all Linear layers in the model. If True, Shift is also applied to Linears outside LayerNorm-Linear combinations, adding two extra ops; defaults to False.
@@ -85,6 +94,11 @@ python finetune_generation.py ./llama/gptq_argument.json
 - `smooth_search_piece`: When piecewise search is enabled, whether to also search the number of pieces, defaults to False. When True, `smooth_k_piece` is recommended to be 6; searching the piece count is time-consuming, so disable it to speed up Smooth.
 - `do_gptq`: Whether to run GPTQ quantization. GPTQ applies WINT4 quantization with higher accuracy than plain PTQ but a longer quantization time; defaults to False.
 - `gptq_step`: Number of GPTQ steps, i.e. model forward passes, defaults to 8.
+- `do_awq`: Whether to run AWQ quantization. AWQ applies WINT4 quantization with higher accuracy than plain PTQ; defaults to False.
+- `auto_clip`: Whether AWQ automatically searches clipping thresholds and clips the model weights. Clipping benefits quantized-model accuracy, but the search is slow; defaults to False.
+- `autoclip_step`: Number of AutoClip steps, i.e. model forward passes; each round's sampled data is concatenated by default to search the clipping thresholds; defaults to 8.
+
 </div>
````
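To make the two `weight_quant_method` options concrete: abs_max_channel_wise keeps a single abs-max scale per output channel, while groupwise splits each channel's weights into fixed-size groups and keeps one scale per group, which localizes the effect of outlier rows. A minimal numpy sketch (illustrative only; the group size of 128 is an assumption for the sketch, not a documented toolchain default):

```python
# Illustrative only: per-channel vs group-wise abs-max scales for a weight matrix.
import numpy as np

w = np.random.randn(4096, 1024).astype("float32")  # [in_features, out_features]
group_size = 128  # assumed for the sketch

# abs_max_channel_wise: one abs-max scale per output channel.
channel_scales = np.abs(w).max(axis=0)  # shape [1024]

# groupwise: one abs-max scale per (group of input rows, output channel),
# so an outlier row only inflates the scale of its own group.
w_grouped = w.reshape(-1, group_size, w.shape[1])  # [32, 128, 1024]
group_scales = np.abs(w_grouped).max(axis=1)  # shape [32, 1024]

print(channel_scales.shape, group_scales.shape)
```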

llm/finetune_generation.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -556,7 +556,13 @@ def compute_metrics_do_generation(eval_preds):
             raise NotImplementedError(
                 "PTQ strategy not supported for LoRA model. Please merge lora parameters to pretrain model first."
             )
-        from quant import apply_ptq, apply_shift, apply_smooth, get_ptq_model_config
+        from quant import (
+            apply_autoclip,
+            apply_ptq,
+            apply_shift,
+            apply_smooth,
+            get_ptq_model_config,
+        )
 
         trainer.model.eval()
         trainer.model.config.quantization_config.quant_type = quant_args.quant_type
@@ -575,6 +581,9 @@ def compute_metrics_do_generation(eval_preds):
         if quant_args.smooth:
             apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config)
 
+        if quant_args.auto_clip:
+            apply_autoclip(quant_args, trainer, ptq_dataloader)
+
         apply_ptq(quant_args, trainer, ptq_dataloader)
         trainer.save_model(merge_tensor_parallel=training_args.tensor_parallel_degree > 1)
```
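
Read together, these hunks fix the dispatch order of the pre-quant passes: with the AWQ config below, Smooth runs AWQSearch, AutoClip then clips the weights, and only afterwards does `apply_ptq` collect scales and quantize. A condensed sketch (the `do_ptq` guard and the `apply_shift` hook are assumed from the surrounding file, which this diff does not show):

```python
# Condensed sketch of the quantization flow after this commit; guard and
# shift hook assumed from context, not copied verbatim from finetune_generation.py.
from quant import apply_autoclip, apply_ptq, apply_shift, apply_smooth

if quant_args.do_ptq:
    trainer.model.eval()
    if quant_args.shift:
        apply_shift(quant_args, trainer, ptq_dataloader, ptq_model_config)
    if quant_args.smooth:  # with do_awq=True this pass runs AWQSearch
        apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config)
    if quant_args.auto_clip:  # new in this commit
        apply_autoclip(quant_args, trainer, ptq_dataloader)
    apply_ptq(quant_args, trainer, ptq_dataloader)  # observer picked by weight_quant_method
```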

llm/llama/awq_argument.json

Lines changed: 22 additions & 0 deletions
```diff
@@ -0,0 +1,22 @@
+{
+    "model_name_or_path": "./checkpoints/llama_sft_ckpts",
+    "per_device_train_batch_size": 8,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "src_length": 1024,
+    "max_length": 2048,
+    "fp16": true,
+    "fp16_opt_level": "O2",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/llama_ptq_ckpts",
+    "do_eval": true,
+    "eval_with_do_generation": false,
+    "do_ptq": true,
+    "quant_type": "weight_only_int4",
+    "weight_quant_method": "groupwise",
+    "ptq_step": 16,
+    "smooth": true,
+    "auto_clip": true,
+    "autoclip_step": 1,
+    "do_awq": true
+}
```
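
As documented above, this config is run with `python finetune_generation.py ./llama/awq_argument.json`. For reference, the `quant_type`/`weight_quant_method` pair maps to a 4-bit group-wise weight observer; a small sketch mirroring the selection logic added to `apply_ptq` in `llm/quant.py` below:

```python
# Sketch: observer selection implied by the config above, mirroring the
# branching added to apply_ptq in llm/quant.py.
from paddleslim.quant.observers import (
    AbsMaxChannelWiseWeightObserver,
    GroupWiseWeightObserver,
)

quant_type = "weight_only_int4"    # from awq_argument.json
weight_quant_method = "groupwise"  # from awq_argument.json

observer_cls = {
    "abs_max_channel_wise": AbsMaxChannelWiseWeightObserver,
    "groupwise": GroupWiseWeightObserver,
}[weight_quant_method]

# weight_only_int4 -> 4-bit weights, activations left in floating point.
weight_observer = observer_cls(quant_bits=4)
```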

llm/quant.py

Lines changed: 77 additions & 25 deletions
```diff
@@ -23,6 +23,8 @@
 from paddle.quantization import PTQ, QAT, QuantConfig
 from paddleslim.quant.advanced import (
     GPTQ,
+    AutoClip,
+    AWQSearch,
     EMASampler,
     MultiStepSampler,
     PieceWiseSearch,
@@ -34,11 +36,16 @@
     QuantizedColumnParallelLinear,
     QuantizedRowParallelLinear,
 )
-from paddleslim.quant.observers import AbsMaxChannelWiseWeightObserver, AVGObserver
+from paddleslim.quant.observers import (
+    AbsMaxChannelWiseWeightObserver,
+    AVGObserver,
+    GroupWiseWeightObserver,
+)
 from paddleslim.quant.observers.abs_max_weight import (
     AbsMaxChannelWiseWeightObserverLayer,
 )
 from paddleslim.quant.observers.avg import AVGObserverLayer
+from paddleslim.quant.observers.groupwise import GroupWiseWeightObserverLayer
 
 from paddlenlp.peft import PrefixModelForCausalLM
 from paddlenlp.peft.lora import (
@@ -96,20 +103,23 @@ def apply_shift(quant_args, trainer, ptq_dataloader, ptq_model_config):
         sample_function=shift_sampler,
         shift_all_linears=quant_args.shift_all_linears,
     )
-
-    trainer.ptq_loop(
-        ptq_dataloader,
-        description="Shift",
-        max_eval_iters=quant_args.shift_step,
-    )
-    shift.update_weight()
+    with paddle.no_grad():
+        trainer.ptq_loop(
+            ptq_dataloader,
+            description="Shift",
+            max_eval_iters=quant_args.shift_step,
+        )
+        shift.update_weight()
     del shift, shift_sampler
     logger.info("***** Shift done *****")
 
 
 def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config):
 
-    logger.info("***** Running Smooth *****")
+    if quant_args.do_awq:
+        logger.info("***** Running AWQ *****")
+    else:
+        logger.info("***** Running Smooth *****")
     smooth_sampler = MultiStepSampler() if quant_args.smooth_sampler == "multi_step" else None
     if quant_args.smooth_piecewise_search:
         search_func = PieceWiseSearch(
@@ -123,6 +133,12 @@ def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config):
             weight_quant_method="abs_max_channel_wise",
             act_quant_method="avg",
         )
+    elif quant_args.do_awq:
+        search_func = AWQSearch(
+            n_grid=20,
+            bits_length=4,
+            weight_quant_method=quant_args.weight_quant_method,
+        )
     else:
         search_func = None
     smooth = Smooth(
@@ -132,31 +148,64 @@ def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config):
         smooth_all_linears=quant_args.smooth_all_linears,
         sample_function=smooth_sampler,
         search_function=search_func,
+        smooth_method="awq" if quant_args.do_awq else "smoothquant",
     )
-    trainer.ptq_loop(
-        ptq_dataloader,
-        description="Smooth",
-        max_eval_iters=quant_args.smooth_step,
-    )
+    with paddle.no_grad():
+        trainer.ptq_loop(
+            ptq_dataloader,
+            description="Smooth",
+            max_eval_iters=quant_args.smooth_step,
+        )
 
-    smooth.update_weight()
+        smooth.update_weight()
     del smooth, smooth_sampler, search_func
     logger.info("***** Smooth done *****")
 
 
+def apply_autoclip(quant_args, trainer, ptq_dataloader):
+    """
+    AutoClip
+    """
+    print("-------------------Start AutoClip------------------")
+    sampler = MultiStepSampler()
+    auto_clip = AutoClip(
+        trainer.model,
+        weight_bits=4,
+        weight_quant_method=quant_args.weight_quant_method,
+        sample_function=sampler,
+        n_grid=20,
+        max_shrink=0.5,
+    )
+    with paddle.no_grad():
+        trainer.ptq_loop(
+            ptq_dataloader,
+            description="AutoClip",
+            max_eval_iters=quant_args.autoclip_step,
+        )
+        auto_clip.auto_clip()
+    del sampler, auto_clip
+    logger.info("***** AutoClip done *****")
+
+
 def apply_ptq(quant_args, trainer, ptq_dataloader):
     logger.info("***** Running PTQ *****")
     q_config = QuantConfig(activation=None, weight=None)
+    if quant_args.weight_quant_method == "abs_max_channel_wise":
+        weight_observer = AbsMaxChannelWiseWeightObserver
+    elif quant_args.weight_quant_method == "groupwise":
+        weight_observer = GroupWiseWeightObserver
+    else:
+        raise ValueError("weight_quant_method should be one of ['abs_max_channel_wise', 'groupwise']")
 
     if quant_args.quant_type == "a8w8":
         activation = AVGObserver(quant_bits=8)
-        weight = AbsMaxChannelWiseWeightObserver(quant_bits=8)
+        weight = weight_observer(quant_bits=8)
     elif quant_args.quant_type == "weight_only_int4":
         activation = None
-        weight = AbsMaxChannelWiseWeightObserver(quant_bits=4)
+        weight = weight_observer(quant_bits=4)
    elif quant_args.quant_type == "weight_only_int8":
         activation = None
-        weight = AbsMaxChannelWiseWeightObserver(quant_bits=8)
+        weight = weight_observer(quant_bits=8)
     else:
         raise ValueError("quant_type should be one of ['a8w8', 'weight_only_int4', 'weight_only_int8']")
@@ -181,10 +230,12 @@ def apply_ptq(quant_args, trainer, ptq_dataloader):
         if isinstance(cur_layer, AbsMaxChannelWiseWeightObserverLayer):
             if "_observer" not in cur_name:
                 weight_scales[cur_name] = cur_layer.scales().numpy().tolist()
+        if isinstance(cur_layer, GroupWiseWeightObserverLayer):
+            if "_observer" not in cur_name:
+                weight_scales[cur_name] = cur_layer.scales().numpy().tolist()
         if isinstance(cur_layer, AVGObserverLayer):
             if "_observer" not in cur_name:
                 act_scales[cur_name] = cur_layer.scales().numpy().tolist()
-
     weight_scales_path = os.path.join(trainer.args.output_dir, "weight_scales.json")
     with open(weight_scales_path, "w") as f:
         json.dump(weight_scales, f)
@@ -210,12 +261,13 @@ def apply_gptq(quant_args, trainer, ptq_dataloader):
         parent_layer, sub_name = find_parent_layer_and_sub_name(model, cur_name)
         cur_quant_layer = GPTQ(cur_layer)
         setattr(parent_layer, sub_name, cur_quant_layer)
-        trainer.ptq_loop(
-            ptq_dataloader,
-            description="GPTQ",
-            max_eval_iters=quant_args.gptq_step,
-        )
-        cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True)
+        with paddle.no_grad():
+            trainer.ptq_loop(
+                ptq_dataloader,
+                description="GPTQ",
+                max_eval_iters=quant_args.gptq_step,
+            )
+            cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True)
         del cur_quant_layer
         setattr(parent_layer, sub_name, cur_layer)
     logger.info("***** GPTQ done *****")
```
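
For intuition about the `n_grid=20` arguments passed to `AWQSearch` and `AutoClip`: the AWQ paper grid-searches a per-channel scaling exponent alpha and keeps the value that minimizes the output error after fake-quantizing the scaled weights. A toy numpy sketch of that search (illustrative only; PaddleSlim's `AWQSearch` differs in detail):

```python
# Toy sketch of AWQ's scale search (paper: https://arxiv.org/abs/2306.00978).
# Illustrative only -- not PaddleSlim's implementation.
import numpy as np


def fake_quant(w, bits=4):
    """Symmetric per-channel abs-max quantize/dequantize."""
    scale = np.abs(w).max(axis=0, keepdims=True) / (2 ** (bits - 1) - 1)
    return np.round(w / scale) * scale


def awq_scale_search(x, w, n_grid=20, bits=4):
    """Grid-search the per-input-channel scaling exponent alpha in [0, 1)."""
    x_mag = np.abs(x).mean(axis=0)  # per-input-channel activation magnitude
    best_scale, best_err = None, np.inf
    for i in range(n_grid):
        alpha = i / n_grid
        s = np.clip(x_mag**alpha, 1e-4, None)
        # Scale weights up by s (activations would be scaled down by 1/s),
        # then measure the output error introduced by quantization.
        w_q = fake_quant(w * s[:, None], bits) / s[:, None]
        err = ((x @ w_q - x @ w) ** 2).mean()
        if err < best_err:
            best_scale, best_err = s, err
    return best_scale


x = np.random.randn(256, 64).astype("float32")  # calibration activations
w = np.random.randn(64, 32).astype("float32")   # linear weights [in, out]
scales = awq_scale_search(x, w)
```

`AutoClip` is the complementary step from the same paper: over a similar grid, it shrinks each channel's clipping threshold below the abs-max (bounded here by `max_shrink=0.5`) and keeps the threshold that minimizes the same kind of reconstruction error.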
