
Commit c177492

fix model quantization to support timvx backend (#209)
* fix model quantization to support timvx backend
* update table

Parent: 05fe0a4

File tree

5 files changed: +21 additions, -13 deletions


benchmark/README.md

Lines changed: 2 additions & 0 deletions

@@ -350,13 +350,15 @@ Benchmarking ...
 backend=cv.dnn.DNN_BACKEND_TIMVX
 target=cv.dnn.DNN_TARGET_NPU
 mean       median     min        input size   model
+5.08       4.72       4.70       [160, 120]   YuNet with ['face_detection_yunet_2023mar_int8.onnx']
 45.83      47.06      43.04      [150, 150]   SFace with ['face_recognition_sface_2021dec_int8.onnx']
 29.20      27.55      26.25      [112, 112]   FacialExpressionRecog with ['facial_expression_recognition_mobilefacenet_2022july_int8.onnx']
 18.47      18.16      17.96      [224, 224]   MPHandPose with ['handpose_estimation_mediapipe_2023feb_int8.onnx']
 28.25      28.35      27.98      [192, 192]   PPHumanSeg with ['human_segmentation_pphumanseg_2023mar_int8.onnx']
 149.05     155.10     144.42     [224, 224]   MobileNet with ['image_classification_mobilenetv1_2022apr_int8.onnx']
 147.40     147.49     135.90     [224, 224]   MobileNet with ['image_classification_mobilenetv2_2022apr_int8.onnx']
 75.91      79.27      71.98      [224, 224]   PPResNet with ['image_classification_ppresnet50_2022jan_int8.onnx']
+30.98      30.56      29.36      [320, 240]   LPD_YuNet with ['license_plate_detection_lpd_yunet_2023mar_int8.onnx']
 117.71     119.69     107.37     [416, 416]   NanoDet with ['object_detection_nanodet_2022nov_int8.onnx']
 379.46     366.19     360.02     [640, 640]   YoloX with ['object_detection_yolox_2022nov_int8.onnx']
 33.90      36.32      31.71      [192, 192]   MPPalmDet with ['palm_detection_mediapipe_2023feb_int8.onnx']
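The two added rows report the newly supported int8 models on the TIM-VX NPU path. As a usage sketch, the backend/target pair in the table header maps onto OpenCV's DNN API roughly as below; the model and image paths are illustrative, and an OpenCV build with TIM-VX enabled plus an NPU-capable device are assumed:

```python
# Hedged sketch: run the re-quantized YuNet model through the TIM-VX
# backend, mirroring the backend/target lines from the benchmark output.
import cv2 as cv

net = cv.dnn.readNet('face_detection_yunet_2023mar_int8.onnx')  # illustrative path
net.setPreferableBackend(cv.dnn.DNN_BACKEND_TIMVX)
net.setPreferableTarget(cv.dnn.DNN_TARGET_NPU)

# 160x120 matches the input size benchmarked for YuNet above.
blob = cv.dnn.blobFromImage(cv.imread('sample.jpg'), size=(160, 120))
net.setInput(blob)
out = net.forward()
```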

benchmark/color_table.svg

Lines changed: 4 additions & 4 deletions
Lines changed: 2 additions & 2 deletions (Git LFS pointer; file name not captured)

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ba0af078d0001f4f91cd74bf8bf78bacdb04e0b6cfa00b02bf258c30d57a0483
-size 99673
+oid sha256:321aa5a6afabf7ecc46a3d06bfab2b579dc96eb5c3be7edd365fa04502ad9294
+size 100416
Lines changed: 2 additions & 2 deletions (Git LFS pointer; file name not captured)

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c1d64b3d0e5a8470cfba63ea6dc8752188bfdca733836aea22a2310bef836e5c
-size 1087530
+oid sha256:d67982a014fe93ad04612f565ed23ca010dcb0fd925d880ef0edf9cd7bdf931a
+size 1087142

tools/quantize/quantize-ort.py

Lines changed: 11 additions & 5 deletions

@@ -46,7 +46,7 @@ def get_calibration_data(self, image_dir):
         return blobs

 class Quantize:
-    def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8', data_dim='chw'):
+    def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8', data_dim='chw', nodes_to_exclude=[]):
         self.type_dict = {"uint8" : QuantType.QUInt8, "int8" : QuantType.QInt8}

         self.model_path = model_path
@@ -55,6 +55,7 @@ def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_
         self.per_channel = per_channel
         self.act_type = act_type
         self.wt_type = wt_type
+        self.nodes_to_exclude = nodes_to_exclude

         # data reader
         self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, data_dim)
@@ -80,15 +81,18 @@ def run(self):
                         quant_format=QuantFormat.QOperator, # start from onnxruntime==1.11.0, quant_format is set to QuantFormat.QDQ by default, which performs fake quantization
                         per_channel=self.per_channel,
                         weight_type=self.type_dict[self.wt_type],
-                        activation_type=self.type_dict[self.act_type])
+                        activation_type=self.type_dict[self.act_type],
+                        nodes_to_exclude=self.nodes_to_exclude)
         if new_model_path != self.model_path:
             os.remove(new_model_path)
         print('\tQuantized model saved to {}'.format(output_name))

 models=dict(
-    yunet=Quantize(model_path='../../models/face_detection_yunet/face_detection_yunet_2022mar.onnx',
+    yunet=Quantize(model_path='../../models/face_detection_yunet/face_detection_yunet_2023mar.onnx',
                    calibration_image_dir='../../benchmark/data/face_detection',
-                   transforms=Compose([Resize(size=(160, 120))])),
+                   transforms=Compose([Resize(size=(160, 120))]),
+                   nodes_to_exclude=['MaxPool_5', 'MaxPool_18', 'MaxPool_25', 'MaxPool_32'],
+                   ),
     sface=Quantize(model_path='../../models/face_recognition_sface/face_recognition_sface_2021dec.onnx',
                    calibration_image_dir='../../benchmark/data/face_recognition',
                    transforms=Compose([Resize(size=(112, 112))])),
@@ -119,7 +123,9 @@ def run(self):
                                          ColorConvert(ctype=cv.COLOR_BGR2RGB)]), data_dim='hwc'),
     lpd_yunet=Quantize(model_path='../../models/license_plate_detection_yunet/license_plate_detection_lpd_yunet_2023mar.onnx',
                        calibration_image_dir='../../benchmark/data/license_plate_detection',
-                       transforms=Compose([Resize(size=(320, 240))])),
+                       transforms=Compose([Resize(size=(320, 240))]),
+                       nodes_to_exclude=['MaxPool_5', 'MaxPool_18', 'MaxPool_25', 'MaxPool_32', 'MaxPool_39'],
+                       ),
 )

 if __name__ == '__main__':
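Under the hood, Quantize.run() now forwards the nodes_to_exclude list to onnxruntime's quantize_static() entry point. A minimal self-contained sketch of that call follows, with a stub calibration reader, an assumed input tensor name ('input'), and illustrative paths; the real script calibrates on images under ../../benchmark/data instead of random blobs:

```python
import numpy as np
from onnxruntime.quantization import (CalibrationDataReader, QuantFormat,
                                      QuantType, quantize_static)

class RandomDataReader(CalibrationDataReader):
    """Stub calibration reader feeding random NCHW blobs (illustration only)."""
    def __init__(self, input_name, shape, num_samples=8):
        self._input_name = input_name
        self._blobs = (np.random.rand(1, *shape).astype(np.float32)
                       for _ in range(num_samples))

    def get_next(self):
        # Return one {input_name: blob} dict per call, then None when done.
        blob = next(self._blobs, None)
        return None if blob is None else {self._input_name: blob}

quantize_static(
    'face_detection_yunet_2023mar.onnx',         # float32 model (input, illustrative)
    'face_detection_yunet_2023mar_int8.onnx',    # quantized model (output, illustrative)
    RandomDataReader('input', (3, 120, 160)),    # 'input' is an assumed tensor name
    quant_format=QuantFormat.QOperator,
    per_channel=False,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QInt8,
    # Same exclusions the commit applies for YuNet: these nodes stay float32.
    nodes_to_exclude=['MaxPool_5', 'MaxPool_18', 'MaxPool_25', 'MaxPool_32'],
)
```

Excluding the named MaxPool nodes keeps them in float32, which, per the commit title, appears to be what makes the resulting int8 models runnable on the TIM-VX backend.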
