23
23
from paddle .quantization import PTQ , QAT , QuantConfig
24
24
from paddleslim .quant .advanced import (
25
25
GPTQ ,
26
+ AutoClip ,
27
+ AWQSearch ,
26
28
EMASampler ,
27
29
MultiStepSampler ,
28
30
PieceWiseSearch ,
34
36
QuantizedColumnParallelLinear ,
35
37
QuantizedRowParallelLinear ,
36
38
)
37
- from paddleslim .quant .observers import AbsMaxChannelWiseWeightObserver , AVGObserver
39
+ from paddleslim .quant .observers import (
40
+ AbsMaxChannelWiseWeightObserver ,
41
+ AVGObserver ,
42
+ GroupWiseWeightObserver ,
43
+ )
38
44
from paddleslim .quant .observers .abs_max_weight import (
39
45
AbsMaxChannelWiseWeightObserverLayer ,
40
46
)
41
47
from paddleslim .quant .observers .avg import AVGObserverLayer
48
+ from paddleslim .quant .observers .groupwise import GroupWiseWeightObserverLayer
42
49
43
50
from paddlenlp .peft import PrefixModelForCausalLM
44
51
from paddlenlp .peft .lora import (
@@ -96,20 +103,23 @@ def apply_shift(quant_args, trainer, ptq_dataloader, ptq_model_config):
96
103
sample_function = shift_sampler ,
97
104
shift_all_linears = quant_args .shift_all_linears ,
98
105
)
99
-
100
- trainer .ptq_loop (
101
- ptq_dataloader ,
102
- description = "Shift" ,
103
- max_eval_iters = quant_args .shift_step ,
104
- )
105
- shift .update_weight ()
106
+ with paddle . no_grad ():
107
+ trainer .ptq_loop (
108
+ ptq_dataloader ,
109
+ description = "Shift" ,
110
+ max_eval_iters = quant_args .shift_step ,
111
+ )
112
+ shift .update_weight ()
106
113
del shift , shift_sampler
107
114
logger .info ("***** Shift done *****" )
108
115
109
116
110
117
def apply_smooth (quant_args , trainer , ptq_dataloader , ptq_model_config ):
111
118
112
- logger .info ("***** Running Smooth *****" )
119
+ if quant_args .do_awq :
120
+ logger .info ("***** Running AWQ *****" )
121
+ else :
122
+ logger .info ("***** Running Smooth *****" )
113
123
smooth_sampler = MultiStepSampler () if quant_args .smooth_sampler == "multi_step" else None
114
124
if quant_args .smooth_piecewise_search :
115
125
search_func = PieceWiseSearch (
@@ -123,6 +133,12 @@ def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config):
123
133
weight_quant_method = "abs_max_channel_wise" ,
124
134
act_quant_method = "avg" ,
125
135
)
136
+ elif quant_args .do_awq :
137
+ search_func = AWQSearch (
138
+ n_grid = 20 ,
139
+ bits_length = 4 ,
140
+ weight_quant_method = quant_args .weight_quant_method ,
141
+ )
126
142
else :
127
143
search_func = None
128
144
smooth = Smooth (
@@ -132,31 +148,64 @@ def apply_smooth(quant_args, trainer, ptq_dataloader, ptq_model_config):
132
148
smooth_all_linears = quant_args .smooth_all_linears ,
133
149
sample_function = smooth_sampler ,
134
150
search_function = search_func ,
151
+ smooth_method = "awq" if quant_args .do_awq else "smoothquant" ,
135
152
)
136
- trainer .ptq_loop (
137
- ptq_dataloader ,
138
- description = "Smooth" ,
139
- max_eval_iters = quant_args .smooth_step ,
140
- )
153
+ with paddle .no_grad ():
154
+ trainer .ptq_loop (
155
+ ptq_dataloader ,
156
+ description = "Smooth" ,
157
+ max_eval_iters = quant_args .smooth_step ,
158
+ )
141
159
142
- smooth .update_weight ()
160
+ smooth .update_weight ()
143
161
del smooth , smooth_sampler , search_func
144
162
logger .info ("***** Smooth done *****" )
145
163
146
164
165
def apply_autoclip(quant_args, trainer, ptq_dataloader):
    """
    Run the AutoClip weight-clipping search on the trainer's model.

    Collects activation samples with a MultiStepSampler by driving
    ``trainer.ptq_loop`` over ``ptq_dataloader`` for at most
    ``quant_args.autoclip_step`` batches (under ``paddle.no_grad()``),
    then searches per-weight clipping thresholds for 4-bit weight
    quantization (n_grid=20, max_shrink=0.5).

    Args:
        quant_args: quantization arguments; reads ``weight_quant_method``
            and ``autoclip_step``.
        trainer: trainer holding the model and the ``ptq_loop`` calibration driver.
        ptq_dataloader: calibration dataloader fed to ``ptq_loop``.
    """
    # Use logger.info with the "***** Running X *****" banner for consistency
    # with the sibling apply_shift / apply_smooth / apply_ptq / apply_gptq
    # functions (the original used a bare print with an ad-hoc banner).
    logger.info("***** Running AutoClip *****")
    sampler = MultiStepSampler()
    auto_clip = AutoClip(
        trainer.model,
        weight_bits=4,
        weight_quant_method=quant_args.weight_quant_method,
        sample_function=sampler,
        n_grid=20,
        max_shrink=0.5,
    )
    # Calibration forward passes only — no gradients needed.
    with paddle.no_grad():
        trainer.ptq_loop(
            ptq_dataloader,
            description="AutoClip",
            max_eval_iters=quant_args.autoclip_step,
        )
        auto_clip.auto_clip()
    # Release the sampler/search objects before the next quantization stage.
    del sampler, auto_clip
    logger.info("***** AutoClip done *****")
188
+
189
+
147
190
def apply_ptq (quant_args , trainer , ptq_dataloader ):
148
191
logger .info ("***** Running PTQ *****" )
149
192
q_config = QuantConfig (activation = None , weight = None )
193
+ if quant_args .weight_quant_method == "abs_max_channel_wise" :
194
+ weight_observer = AbsMaxChannelWiseWeightObserver
195
+ elif quant_args .weight_quant_method == "groupwise" :
196
+ weight_observer = GroupWiseWeightObserver
197
+ else :
198
+ raise ValueError ("weight_quant_method should be one of ['abs_max_channel_wise', 'groupwise']" )
150
199
151
200
if quant_args .quant_type == "a8w8" :
152
201
activation = AVGObserver (quant_bits = 8 )
153
- weight = AbsMaxChannelWiseWeightObserver (quant_bits = 8 )
202
+ weight = weight_observer (quant_bits = 8 )
154
203
elif quant_args .quant_type == "weight_only_int4" :
155
204
activation = None
156
- weight = AbsMaxChannelWiseWeightObserver (quant_bits = 4 )
205
+ weight = weight_observer (quant_bits = 4 )
157
206
elif quant_args .quant_type == "weight_only_int8" :
158
207
activation = None
159
- weight = AbsMaxChannelWiseWeightObserver (quant_bits = 8 )
208
+ weight = weight_observer (quant_bits = 8 )
160
209
else :
161
210
raise ValueError ("quant_type should be one of ['a8w8', 'weight_only_int4', 'weight_only_int8']" )
162
211
@@ -181,10 +230,12 @@ def apply_ptq(quant_args, trainer, ptq_dataloader):
181
230
if isinstance (cur_layer , AbsMaxChannelWiseWeightObserverLayer ):
182
231
if "_observer" not in cur_name :
183
232
weight_scales [cur_name ] = cur_layer .scales ().numpy ().tolist ()
233
+ if isinstance (cur_layer , GroupWiseWeightObserverLayer ):
234
+ if "_observer" not in cur_name :
235
+ weight_scales [cur_name ] = cur_layer .scales ().numpy ().tolist ()
184
236
if isinstance (cur_layer , AVGObserverLayer ):
185
237
if "_observer" not in cur_name :
186
238
act_scales [cur_name ] = cur_layer .scales ().numpy ().tolist ()
187
-
188
239
weight_scales_path = os .path .join (trainer .args .output_dir , "weight_scales.json" )
189
240
with open (weight_scales_path , "w" ) as f :
190
241
json .dump (weight_scales , f )
@@ -210,12 +261,13 @@ def apply_gptq(quant_args, trainer, ptq_dataloader):
210
261
parent_layer , sub_name = find_parent_layer_and_sub_name (model , cur_name )
211
262
cur_quant_layer = GPTQ (cur_layer )
212
263
setattr (parent_layer , sub_name , cur_quant_layer )
213
- trainer .ptq_loop (
214
- ptq_dataloader ,
215
- description = "GPTQ" ,
216
- max_eval_iters = quant_args .gptq_step ,
217
- )
218
- cur_quant_layer .fasterquant (percdamp = 0.1 , groupsize = - 1 , actorder = True )
264
+ with paddle .no_grad ():
265
+ trainer .ptq_loop (
266
+ ptq_dataloader ,
267
+ description = "GPTQ" ,
268
+ max_eval_iters = quant_args .gptq_step ,
269
+ )
270
+ cur_quant_layer .fasterquant (percdamp = 0.1 , groupsize = - 1 , actorder = True )
219
271
del cur_quant_layer
220
272
setattr (parent_layer , sub_name , cur_layer )
221
273
logger .info ("***** GPTQ done *****" )
0 commit comments