@@ -189,10 +189,10 @@ def eval_quant_model():
189
189
valid_data_num = 0
190
190
max_eval_data_num = 200
191
191
if g_quant_config .eval_sample_generator is not None :
192
- feed_dict = False
192
+ feed_dict = False
193
193
eval_dataloader = g_quant_config .eval_sample_generator
194
194
else :
195
- feed_dict = True
195
+ feed_dict = True
196
196
eval_dataloader = g_quant_config .eval_dataloader
197
197
for i , data in enumerate (eval_dataloader ()):
198
198
with paddle .static .scope_guard (float_scope ):
@@ -236,12 +236,20 @@ def eval_quant_model():
236
236
def quantize (cfg ):
237
237
"""model quantize job"""
238
238
algo = cfg ["algo" ] if 'algo' in cfg else g_quant_config .algo [0 ][0 ]
239
- hist_percent = cfg ["hist_percent" ] if "hist_percent" in cfg else g_quant_config .hist_percent [0 ][0 ]
240
- bias_correct = cfg ["bias_correct" ] if "bias_correct" in cfg else g_quant_config .bias_correct [0 ][0 ]
241
- batch_size = cfg ["batch_size" ] if "batch_size" in cfg else g_quant_config .batch_size [0 ][0 ]
242
- batch_num = cfg ["batch_num" ] if "batch_num" in cfg else g_quant_config .batch_num [0 ][0 ]
243
- weight_quantize_type = cfg ["weight_quantize_type" ] if "weight_quantize_type" in cfg else g_quant_config .weight_quantize_type [0 ]
244
- print (hist_percent , bias_correct , batch_size , batch_num , weight_quantize_type )
239
+ hist_percent = cfg [
240
+ "hist_percent" ] if "hist_percent" in cfg else g_quant_config .hist_percent [
241
+ 0 ][0 ]
242
+ bias_correct = cfg [
243
+ "bias_correct" ] if "bias_correct" in cfg else g_quant_config .bias_correct [
244
+ 0 ][0 ]
245
+ batch_size = cfg [
246
+ "batch_size" ] if "batch_size" in cfg else g_quant_config .batch_size [0 ][
247
+ 0 ]
248
+ batch_num = cfg [
249
+ "batch_num" ] if "batch_num" in cfg else g_quant_config .batch_num [0 ][0 ]
250
+ weight_quantize_type = cfg [
251
+ "weight_quantize_type" ] if "weight_quantize_type" in cfg else g_quant_config .weight_quantize_type [
252
+ 0 ]
245
253
246
254
quant_post ( \
247
255
executor = g_quant_config .executor , \
@@ -279,34 +287,35 @@ def quantize(cfg):
279
287
return emd_loss
280
288
281
289
282
- def quant_post_hpo (executor ,
283
- place ,
284
- model_dir ,
285
- quantize_model_path ,
286
- train_sample_generator = None ,
287
- eval_sample_generator = None ,
288
- train_dataloader = None ,
289
- eval_dataloader = None ,
290
- eval_function = None ,
291
- model_filename = None ,
292
- params_filename = None ,
293
- save_model_filename = '__model__' ,
294
- save_params_filename = '__params__' ,
295
- scope = None ,
296
- quantizable_op_type = ["conv2d" , "depthwise_conv2d" , "mul" ],
297
- is_full_quantize = False ,
298
- weight_bits = 8 ,
299
- activation_bits = 8 ,
300
- weight_quantize_type = ['channel_wise_abs_max' ],
301
- algo = ["KL" , "hist" , "avg" , "mse" ],
302
- bias_correct = [True , False ],
303
- hist_percent = [0.98 , 0.999 ], ### uniform sample in list.
304
- batch_size = [10 , 30 ], ### uniform sample in list.
305
- batch_num = [10 , 30 ], ### uniform sample in list.
306
- optimize_model = False ,
307
- is_use_cache_file = False ,
308
- cache_dir = "./temp_post_training" ,
309
- runcount_limit = 30 ):
290
+ def quant_post_hpo (
291
+ executor ,
292
+ place ,
293
+ model_dir ,
294
+ quantize_model_path ,
295
+ train_sample_generator = None ,
296
+ eval_sample_generator = None ,
297
+ train_dataloader = None ,
298
+ eval_dataloader = None ,
299
+ eval_function = None ,
300
+ model_filename = None ,
301
+ params_filename = None ,
302
+ save_model_filename = '__model__' ,
303
+ save_params_filename = '__params__' ,
304
+ scope = None ,
305
+ quantizable_op_type = ["conv2d" , "depthwise_conv2d" , "mul" ],
306
+ is_full_quantize = False ,
307
+ weight_bits = 8 ,
308
+ activation_bits = 8 ,
309
+ weight_quantize_type = ['channel_wise_abs_max' ],
310
+ algo = ["KL" , "hist" , "avg" , "mse" ],
311
+ bias_correct = [True , False ],
312
+ hist_percent = [0.98 , 0.999 ], ### uniform sample in list.
313
+ batch_size = [10 , 30 ], ### uniform sample in list.
314
+ batch_num = [10 , 30 ], ### uniform sample in list.
315
+ optimize_model = False ,
316
+ is_use_cache_file = False ,
317
+ cache_dir = "./temp_post_training" ,
318
+ runcount_limit = 30 ):
310
319
"""
311
320
The function utilizes static post training quantization method to
312
321
quantize the fp32 model. It uses calibrate data to calculate the
@@ -360,25 +369,27 @@ def quant_post_hpo(executor,
360
369
361
370
global g_quant_config
362
371
g_quant_config = QuantConfig (
363
- executor , place , model_dir , quantize_model_path , algo , hist_percent ,
372
+ executor , place , model_dir , quantize_model_path , algo , hist_percent ,
364
373
bias_correct , batch_size , batch_num , train_sample_generator ,
365
374
eval_sample_generator , train_dataloader , eval_dataloader , eval_function ,
366
- model_filename , params_filename ,
367
- save_model_filename , save_params_filename , scope , quantizable_op_type ,
368
- is_full_quantize , weight_bits , activation_bits , weight_quantize_type ,
369
- optimize_model , is_use_cache_file , cache_dir )
375
+ model_filename , params_filename , save_model_filename ,
376
+ save_params_filename , scope , quantizable_op_type , is_full_quantize ,
377
+ weight_bits , activation_bits , weight_quantize_type , optimize_model ,
378
+ is_use_cache_file , cache_dir )
370
379
cs = ConfigurationSpace ()
371
380
372
381
hyper_params = []
373
382
374
383
if 'hist' in algo :
375
384
hist_percent = UniformFloatHyperparameter (
376
- "hist_percent" , hist_percent [0 ], hist_percent [1 ], default_value = hist_percent [0 ])
385
+ "hist_percent" ,
386
+ hist_percent [0 ],
387
+ hist_percent [1 ],
388
+ default_value = hist_percent [0 ])
377
389
hyper_params .append (hist_percent )
378
390
379
391
if len (algo ) > 1 :
380
- algo = CategoricalHyperparameter (
381
- "algo" , algo , default_value = algo [0 ])
392
+ algo = CategoricalHyperparameter ("algo" , algo , default_value = algo [0 ])
382
393
hyper_params .append (algo )
383
394
else :
384
395
algo = algo [0 ]
@@ -397,7 +408,10 @@ def quant_post_hpo(executor,
397
408
weight_quantize_type = weight_quantize_type [0 ]
398
409
if len (batch_size ) > 1 :
399
410
batch_size = UniformIntegerHyperparameter (
400
- "batch_size" , batch_size [0 ], batch_size [1 ], default_value = batch_size [0 ])
411
+ "batch_size" ,
412
+ batch_size [0 ],
413
+ batch_size [1 ],
414
+ default_value = batch_size [0 ])
401
415
hyper_params .append (batch_size )
402
416
else :
403
417
batch_size = batch_size [0 ]
@@ -407,7 +421,7 @@ def quant_post_hpo(executor,
407
421
"batch_num" , batch_num [0 ], batch_num [1 ], default_value = batch_num [0 ])
408
422
hyper_params .append (batch_num )
409
423
else :
410
- batch_num = batch_num [0 ]
424
+ batch_num = batch_num [0 ]
411
425
412
426
if len (hyper_params ) == 0 :
413
427
quant_post ( \
0 commit comments