
Commit 9ba5216

Author: Wojciech Uss
Added an option to use external FP32 model in QAT comparison test (#22858) (#22873)
Parent: 781f8b2

1 file changed (+12 additions, -4 deletions)

python/paddle/fluid/contrib/slim/tests/qat_int8_nlp_comparison.py

@@ -48,9 +48,11 @@ def parse_args():
     parser.add_argument(
         '--qat_model', type=str, default='', help='A path to a QAT model.')
     parser.add_argument(
-        '--save_model',
-        action='store_true',
-        help='If used, the QAT model will be saved after all transformations')
+        '--fp32_model',
+        type=str,
+        default='',
+        help='A path to an FP32 model. If empty, the QAT model will be used for FP32 inference.'
+    )
     parser.add_argument('--infer_data', type=str, default='', help='Data file.')
     parser.add_argument(
         '--labels', type=str, default='', help='File with labels.')
@@ -239,7 +241,10 @@ def test_graph_transformation(self):
             return
 
         qat_model_path = test_case_args.qat_model
+        assert qat_model_path, 'The QAT model path cannot be empty. Please, use the --qat_model option.'
+        fp32_model_path = test_case_args.fp32_model if test_case_args.fp32_model else qat_model_path
         data_path = test_case_args.infer_data
+        assert data_path, 'The dataset path cannot be empty. Please, use the --infer_data option.'
         labels_path = test_case_args.labels
         batch_size = test_case_args.batch_size
         batch_num = test_case_args.batch_num
@@ -250,6 +255,7 @@ def test_graph_transformation(self):
 
         _logger.info('QAT FP32 & INT8 prediction run.')
         _logger.info('QAT model: {0}'.format(qat_model_path))
+        _logger.info('FP32 model: {0}'.format(fp32_model_path))
         _logger.info('Dataset: {0}'.format(data_path))
         _logger.info('Labels: {0}'.format(labels_path))
         _logger.info('Batch size: {0}'.format(batch_size))
@@ -262,11 +268,12 @@ def test_graph_transformation(self):
             self._reader_creator(data_path, labels_path), batch_size=batch_size)
         fp32_acc, fp32_pps, fp32_lat = self._predict(
             val_reader,
-            qat_model_path,
+            fp32_model_path,
             batch_size,
             batch_num,
             skip_batch_num,
             transform_to_int8=False)
+        _logger.info('FP32: avg accuracy: {0:.6f}'.format(fp32_acc))
         _logger.info('--- QAT INT8 prediction start ---')
         val_reader = paddle.batch(
             self._reader_creator(data_path, labels_path), batch_size=batch_size)
@@ -277,6 +284,7 @@ def test_graph_transformation(self):
             batch_num,
             skip_batch_num,
             transform_to_int8=True)
+        _logger.info('INT8: avg accuracy: {0:.6f}'.format(int8_acc))
 
         self._summarize_performance(fp32_pps, fp32_lat, int8_pps, int8_lat)
         self._compare_accuracy(fp32_acc, int8_acc, acc_diff_threshold)
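
In short, the commit introduces a --fp32_model flag and, when it is left empty, reuses the QAT model path for the FP32 run, so the test's previous behaviour is preserved. Below is a minimal standalone sketch of that fallback; it is not the commit's code, the flag names are taken from the diff above, and the print calls are illustrative stand-ins for the test's _logger.info output.

import argparse


def parse_args():
    # Only the two flags relevant to this commit are reproduced here.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--qat_model', type=str, default='', help='A path to a QAT model.')
    parser.add_argument(
        '--fp32_model',
        type=str,
        default='',
        help='A path to an FP32 model. If empty, the QAT model will be used '
        'for FP32 inference.')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    assert args.qat_model, 'The QAT model path cannot be empty.'
    # Fall back to the QAT model when no separate FP32 model is supplied,
    # mirroring the fallback added in this commit.
    fp32_model_path = args.fp32_model if args.fp32_model else args.qat_model
    print('QAT model: {0}'.format(args.qat_model))
    print('FP32 model: {0}'.format(fp32_model_path))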
