@@ -48,9 +48,11 @@ def parse_args():
48
48
parser .add_argument (
49
49
'--qat_model' , type = str , default = '' , help = 'A path to a QAT model.' )
50
50
parser .add_argument (
51
- '--save_model' ,
52
- action = 'store_true' ,
53
- help = 'If used, the QAT model will be saved after all transformations' )
51
+ '--fp32_model' ,
52
+ type = str ,
53
+ default = '' ,
54
+ help = 'A path to an FP32 model. If empty, the QAT model will be used for FP32 inference.'
55
+ )
54
56
parser .add_argument ('--infer_data' , type = str , default = '' , help = 'Data file.' )
55
57
parser .add_argument (
56
58
'--labels' , type = str , default = '' , help = 'File with labels.' )
@@ -239,7 +241,10 @@ def test_graph_transformation(self):
239
241
return
240
242
241
243
qat_model_path = test_case_args .qat_model
244
+ assert qat_model_path , 'The QAT model path cannot be empty. Please, use the --qat_model option.'
245
+ fp32_model_path = test_case_args .fp32_model if test_case_args .fp32_model else qat_model_path
242
246
data_path = test_case_args .infer_data
247
+ assert data_path , 'The dataset path cannot be empty. Please, use the --infer_data option.'
243
248
labels_path = test_case_args .labels
244
249
batch_size = test_case_args .batch_size
245
250
batch_num = test_case_args .batch_num
@@ -250,6 +255,7 @@ def test_graph_transformation(self):
250
255
251
256
_logger .info ('QAT FP32 & INT8 prediction run.' )
252
257
_logger .info ('QAT model: {0}' .format (qat_model_path ))
258
+ _logger .info ('FP32 model: {0}' .format (fp32_model_path ))
253
259
_logger .info ('Dataset: {0}' .format (data_path ))
254
260
_logger .info ('Labels: {0}' .format (labels_path ))
255
261
_logger .info ('Batch size: {0}' .format (batch_size ))
@@ -262,11 +268,12 @@ def test_graph_transformation(self):
262
268
self ._reader_creator (data_path , labels_path ), batch_size = batch_size )
263
269
fp32_acc , fp32_pps , fp32_lat = self ._predict (
264
270
val_reader ,
265
- qat_model_path ,
271
+ fp32_model_path ,
266
272
batch_size ,
267
273
batch_num ,
268
274
skip_batch_num ,
269
275
transform_to_int8 = False )
276
+ _logger .info ('FP32: avg accuracy: {0:.6f}' .format (fp32_acc ))
270
277
_logger .info ('--- QAT INT8 prediction start ---' )
271
278
val_reader = paddle .batch (
272
279
self ._reader_creator (data_path , labels_path ), batch_size = batch_size )
@@ -277,6 +284,7 @@ def test_graph_transformation(self):
277
284
batch_num ,
278
285
skip_batch_num ,
279
286
transform_to_int8 = True )
287
+ _logger .info ('INT8: avg accuracy: {0:.6f}' .format (int8_acc ))
280
288
281
289
self ._summarize_performance (fp32_pps , fp32_lat , int8_pps , int8_lat )
282
290
self ._compare_accuracy (fp32_acc , int8_acc , acc_diff_threshold )
0 commit comments