Skip to content

Commit 8aa1127

Browse files
Tom/tflite dequantize (#1322)
* Implement const dequantize pushing for per-axis dequantization Signed-off-by: Tom Wildenhain <[email protected]> * Add --dequantize flag Signed-off-by: Tom Wildenhain <[email protected]> * pylint fixes Signed-off-by: Tom Wildenhain <[email protected]> * Update readme Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 4d83a0e commit 8aa1127

File tree

10 files changed

+212
-31
lines changed

10 files changed

+212
-31
lines changed

README.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,12 +130,14 @@ You find an end-to-end tutorial for ssd-mobilenet [here](tutorials/ConvertingSSD
130130
python -m tf2onnx.convert
131131
--saved-model SOURCE_SAVED_MODEL_PATH |
132132
--checkpoint SOURCE_CHECKPOINT_METAFILE_PATH |
133+
--tflite SOURCE_TFLITE_PATH |
133134
--input | --graphdef SOURCE_GRAPHDEF_PB
134135
--output TARGET_ONNX_MODEL
135136
[--inputs GRAPH_INPUTS]
136137
[--outputs GRAPH_OUTPUS]
137138
[--inputs-as-nchw inputs_provided_as_nchw]
138139
[--opset OPSET]
140+
[--dequantize]
139141
[--tag TAG]
140142
[--signature_def SIGNATURE_DEF]
141143
[--concrete_function CONCRETE_FUNCTION]
@@ -158,6 +160,12 @@ TensorFlow model as saved_model. We expect the path to the saved_model directory
158160

159161
TensorFlow model as checkpoint. We expect the path to the .meta file.
160162

163+
#### --tflite
164+
165+
(This is experimental)
166+
167+
Convert a tflite model by providing a path to the .tflite file. Inputs/outputs do not need to be specified.
168+
161169
#### --input or --graphdef
162170

163171
TensorFlow model as graphdef file.
@@ -182,6 +190,12 @@ ONNX requires default values for graph inputs to be constant, while Tensorflow's
182190

183191
By default we use the opset 8 to generate the graph. By specifying ```--opset``` the user can override the default to generate a graph with the desired opset. For example ```--opset 5``` would create a onnx graph that uses only ops available in opset 5. Because older opsets have in most cases fewer ops, some models might not convert on a older opset.
184192

193+
#### --dequantize
194+
195+
(This is experimental, only supported for tflite)
196+
197+
Produces a float32 model from a quantized tflite model. Detects ReLU and ReLU6 ops from quantization bounds.
198+
185199
#### --tag
186200

187201
Only valid with parameter `--saved_model`. Specifies the tag in the saved_model to be used. Typical value is 'serve'.

tests/ade20k.jpg

92.9 KB
Loading

tests/run_pretrained_models.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from __future__ import unicode_literals
1010

1111
# pylint: disable=broad-except,logging-not-lazy,unused-argument,unnecessary-lambda,import-outside-toplevel
12-
# pylint: disable=wrong-import-position
12+
# pylint: disable=wrong-import-position,too-many-nested-blocks
1313

1414
import argparse
1515
import os
@@ -79,6 +79,16 @@ def get_car(shape):
7979
return get_img(shape, "car.JPEG", np.float32, should_scale=True)
8080

8181

82+
def get_ade20k(shape):
83+
"""Get truck image from ade20k segmentation dataset."""
84+
return get_img(shape, "ade20k.jpg", np.float32, should_scale=True)
85+
86+
87+
def get_ade20k_uint8(shape):
88+
"""Get truck image from ade20k segmentation dataset."""
89+
return get_img(shape, "ade20k.jpg", np.uint8, should_scale=False)
90+
91+
8292
def get_random(shape):
8393
"""Get random input."""
8494
np.random.seed(42)
@@ -146,6 +156,8 @@ def get_sentence():
146156
_INPUT_FUNC_MAPPING = {
147157
"get_beach": get_beach,
148158
"get_car": get_car,
159+
"get_ade20k": get_ade20k,
160+
"get_ade20k_uint8": get_ade20k_uint8,
149161
"get_random": get_random,
150162
"get_random256": get_random256,
151163
"get_ramp": get_ramp,
@@ -171,7 +183,7 @@ class Test(object):
171183
target = []
172184

173185
def __init__(self, url, local, input_func, input_names, output_names,
174-
disabled=False, rtol=0.01, atol=1e-6,
186+
disabled=False, rtol=0.01, atol=1e-6, ptol=0, dequantize=False,
175187
check_only_shape=False, model_type="frozen", force_input_shape=False,
176188
skip_tensorflow=False, opset_constraints=None, tf_min_version=None, tag=None,
177189
skip_conversion=False, converted_model=None, signature_def=None, concrete_function=None,
@@ -190,6 +202,8 @@ def __init__(self, url, local, input_func, input_names, output_names,
190202
self.structured_outputs = structured_outputs # Needed to determine output order for tf_function
191203
self.rtol = rtol
192204
self.atol = atol
205+
self.ptol = ptol
206+
self.dequantize = dequantize
193207
self.check_only_shape = check_only_shape
194208
self.perf = None
195209
self.tf_runtime = 0
@@ -292,7 +306,7 @@ def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, i
292306
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
293307
input_names=input_names, output_names=self.output_names,
294308
const_node_values=const_node_values, initialized_tables=initialized_tables,
295-
tflite_path=tflite_path)
309+
tflite_path=tflite_path, dequantize=self.dequantize)
296310

297311
def run_caffe2(self, name, model_proto, inputs):
298312
"""Run test again caffe2 backend."""
@@ -531,7 +545,11 @@ def run_tflite():
531545
np.testing.assert_array_equal(tf_res.shape, onnx_res.shape)
532546
else:
533547
for tf_res, onnx_res in zip(tf_results, onnx_results):
534-
np.testing.assert_allclose(tf_res, onnx_res, rtol=self.rtol, atol=self.atol)
548+
good_cnt = np.count_nonzero(np.isclose(tf_res, onnx_res, rtol=self.rtol, atol=self.atol))
549+
bad_cnt = tf_res.size - good_cnt
550+
if bad_cnt > self.ptol / 100 * tf_res.size:
551+
# Prints a nice error message with stats
552+
np.testing.assert_allclose(tf_res, onnx_res, rtol=self.rtol, atol=self.atol)
535553
logger.info("Results: OK")
536554
return True
537555
except Exception:
@@ -658,10 +676,10 @@ def load_tests_from_yaml(path):
658676
opset_constraints.append(c)
659677

660678
kwargs = {}
661-
for kw in ["rtol", "atol", "disabled", "check_only_shape", "model_type", "concrete_function",
679+
for kw in ["rtol", "atol", "ptol", "disabled", "check_only_shape", "model_type", "concrete_function",
662680
"skip_tensorflow", "force_input_shape", "tf_min_version", "tag", "skip_conversion",
663681
"converted_model", "signature_def", "large_model", "structured_outputs", "run_tf_frozen",
664-
"use_custom_ops"]:
682+
"use_custom_ops", "dequantize"]:
665683
if settings.get(kw) is not None:
666684
kwargs[kw] = settings[kw]
667685

tests/run_pretrained_models.yaml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,30 @@ ssd_mobilenet_v2_300_float_tflite:
454454
- TFLite_Detection_PostProcess:1
455455
- TFLite_Detection_PostProcess:2
456456
- TFLite_Detection_PostProcess:3
457+
458+
deeplabv3_mnv2_ade20k_float_tflite:
459+
tf_min_version: 2.1
460+
disabled: false
461+
url: https://github.com/mlcommons/mobile_models/raw/main/v0_7/tflite/deeplabv3_mnv2_ade20k_float.tflite
462+
model: "deeplabv3_mnv2_ade20k_float.tflite"
463+
model_type: tflite
464+
input_get: get_ade20k
465+
ptol: 0.001
466+
inputs:
467+
"MobilenetV2/MobilenetV2/input": [1, 512, 512, 3]
468+
outputs:
469+
- ArgMax
470+
471+
deeplabv3_mnv2_ade20k_uint8_tflite:
472+
tf_min_version: 2.1
473+
disabled: false
474+
url: https://github.com/mlcommons/mobile_models/raw/main/v0_7/tflite/deeplabv3_mnv2_ade20k_uint8.tflite
475+
model: "deeplabv3_mnv2_ade20k_uint8.tflite"
476+
model_type: tflite
477+
input_get: get_ade20k_uint8
478+
ptol: 1.0
479+
dequantize: true
480+
inputs:
481+
"MobilenetV2/MobilenetV2/input": [1, 512, 512, 3]
482+
outputs:
483+
- ArgMax

tf2onnx/convert.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def get_args():
6565
parser.add_argument("--use_default", help="comma-separated list of names of PlaceholderWithDefault ops to "
6666
"change into Identity ops using their default value")
6767
parser.add_argument("--opset", type=int, default=None, help="opset version to use for onnx domain")
68+
parser.add_argument("--dequantize", help="Remove quantization from model. Only supported for tflite currently.",
69+
action="store_true")
6870
parser.add_argument("--custom-ops", help="comma-separated map of custom ops to domains in format OpName:domain")
6971
parser.add_argument("--extra_opset", default=None,
7072
help="extra opset with format like domain:version, e.g. com.microsoft:1")
@@ -104,6 +106,9 @@ def get_args():
104106
args.target = args.target.split(",")
105107
if args.signature_def:
106108
args.signature_def = [args.signature_def]
109+
if args.dequantize:
110+
if not args.tflite:
111+
parser.error("dequantize flag is currently only supported for tflite")
107112
if args.extra_opset:
108113
tokens = args.extra_opset.split(':')
109114
if len(tokens) != 2:
@@ -202,7 +207,8 @@ def main():
202207
use_default=args.use_default,
203208
const_node_values=const_node_values,
204209
initialized_tables=initialized_tables,
205-
tflite_path=tflite_path)
210+
tflite_path=tflite_path,
211+
dequantize=args.dequantize)
206212

207213
onnx_graph = optimizer.optimize_graph(g)
208214

tf2onnx/tflite_handlers/tfl_math.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import logging
99
import numpy as np
10+
from onnx.onnx_pb import TensorProto
1011
from tf2onnx.handler import tfl_op
1112
from tf2onnx import utils
1213

@@ -87,31 +88,76 @@ def to_tf(cls, ctx, node, **kwargs):
8788

8889
@tfl_op(["TFL_QUANTIZE"], onnx_op="QuantizeLinear")
8990
class TflQuantizeOp:
91+
@classmethod
92+
def version_1(cls, ctx, node, dequantize=False, **kwargs):
93+
# We could just let the TFL_QUANTIZE fall through as an unconverted op, but they are added programmatically
94+
# so that might be confusing.
95+
raise ValueError("Opset 10 is required for quantization. Consider using the --dequantize flag or --opset 10.")
96+
9097
@classmethod
9198
def version_10(cls, ctx, node, **kwargs):
9299
scale = node.get_attr_value('scale')
93100
zero_point = node.get_attr_value('zero_point')
94101
axis = node.get_attr_value('quantized_dimension')
95102
np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.output[0]))
96103
if len(scale) > 1 or len(zero_point) > 1:
104+
utils.make_sure(ctx.opset >= 13, "Opset 13 is required for per-axis quantization for node %s", node.name)
97105
node.set_attr("axis", axis)
98106
scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale[0], dtype=np.float32))
99107
zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point[0], dtype=np_q_type))
100108
ctx.replace_inputs(node, [node.input[0], scale_node.output[0], zero_point_node.output[0]])
101109
del node.attr["scale"]
102110
del node.attr["zero_point"]
103111
del node.attr["quantized_dimension"]
112+
if "min" in node.attr:
113+
del node.attr["min"]
114+
if "max" in node.attr:
115+
del node.attr["max"]
104116

105117
@tfl_op(["TFL_DEQUANTIZE"], onnx_op="DequantizeLinear")
106118
class TflDequantizeOp:
107119
@classmethod
108-
def version_10(cls, ctx, node, **kwargs):
120+
def version_1(cls, ctx, node, **kwargs):
121+
scale = np.array(node.get_attr_value('scale'), dtype=np.float32)
122+
zero_point = np.array(node.get_attr_value('zero_point'), dtype=np.float32)
123+
axis = node.get_attr_value('quantized_dimension')
124+
in_rank = ctx.get_rank(node.input[0])
125+
def expand_tensor(t):
126+
if t.shape == (1,):
127+
return t[0]
128+
utils.make_sure(in_rank is not None, "Cannot dequantize node %s with unknown input rank", node.name)
129+
new_shape = [1] * in_rank
130+
new_shape[axis] = t.shape[0]
131+
return t.reshape(new_shape)
132+
scale = expand_tensor(scale)
133+
zero_point = expand_tensor(zero_point)
134+
if node.inputs[0].is_const():
135+
x_val = node.inputs[0].get_tensor_value(as_list=False).astype(np.float32)
136+
new_val = (x_val - zero_point) * scale
137+
dequant_const = ctx.make_const(utils.make_name(node.name), new_val)
138+
ctx.replace_all_inputs(node.output[0], dequant_const.output[0])
139+
ctx.remove_node(node.name)
140+
else:
141+
scale_const = ctx.make_const(utils.make_name(node.name + "_scale"), scale).output[0]
142+
zero_point_const = ctx.make_const(utils.make_name(node.name + "_zero_point"), zero_point).output[0]
143+
cast_node = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.FLOAT},
144+
op_name_scope=node.name).output[0]
145+
sub_node = ctx.make_node("Sub", [cast_node, zero_point_const], op_name_scope=node.name).output[0]
146+
mul_node = ctx.make_node("Mul", [sub_node, scale_const], op_name_scope=node.name).output[0]
147+
ctx.replace_all_inputs(node.output[0], mul_node)
148+
ctx.remove_node(node.name)
149+
150+
@classmethod
151+
def version_10(cls, ctx, node, dequantize=False, **kwargs):
152+
if dequantize:
153+
cls.version_1(ctx, node, dequantize=True, **kwargs)
154+
return
109155
scale = node.get_attr_value('scale')
110156
zero_point = node.get_attr_value('zero_point')
111157
axis = node.get_attr_value('quantized_dimension')
112158
np_q_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[0]))
113159
if len(scale) > 1 or len(zero_point) > 1:
114-
utils.make_sure(ctx.opset >= 13, "Opset 13 is required for per-axis quantization")
160+
utils.make_sure(ctx.opset >= 13, "Opset 13 is required for per-axis quantization for node %s", node.name)
115161
node.set_attr("axis", axis)
116162
scale_node = ctx.make_const(utils.make_name("scale"), np.array(scale, dtype=np.float32))
117163
zero_point_node = ctx.make_const(utils.make_name("zero_point"), np.array(zero_point, dtype=np_q_type))
@@ -122,6 +168,10 @@ def version_10(cls, ctx, node, **kwargs):
122168
del node.attr["scale"]
123169
del node.attr["zero_point"]
124170
del node.attr["quantized_dimension"]
171+
if "min" in node.attr:
172+
del node.attr["min"]
173+
if "max" in node.attr:
174+
del node.attr["max"]
125175

126176
def dynamic_quantize_inputs(ctx, node):
127177
if ctx.opset < 11:

tf2onnx/tflite_rewriters/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
"""tf2onnx.tflite_rewriters module"""
44

55
from tf2onnx.tflite_rewriters.tfl_scan_output_rewriter import rewrite_tfl_scan_outputs
6+
from tf2onnx.tflite_rewriters.tfl_qdq_rewriter import rewrite_tfl_qdq
67

78
__all__ = [
89
"rewrite_tfl_scan_outputs",
10+
"rewrite_tfl_qdq"
911
]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""
5+
tf2onnx.tflite_rewriters.tfl_qdq_rewriter - Remove qdq sequences to dequantize model
6+
"""
7+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
8+
9+
10+
# pylint: disable=missing-docstring
11+
12+
def rewrite_tfl_qdq(g, ops):
13+
pattern0 = \
14+
OpTypePattern('TFL_DEQUANTIZE', name='dequant', inputs=[
15+
OpTypePattern('TFL_QUANTIZE', name='quant'),
16+
])
17+
18+
matcher = GraphMatcher(pattern0, allow_reorder=False)
19+
match_results = list(matcher.match_ops(ops))
20+
if match_results:
21+
for match in match_results:
22+
dequant = match.get_op("dequant")
23+
quant = match.get_op("quant")
24+
inp_node = quant.inputs[0]
25+
for k in ["scale", "quantized_dimension", "zero_point"]:
26+
if dequant.get_attr_value(k) != quant.get_attr_value(k):
27+
continue
28+
needed_relu = None
29+
if all(k in quant.attr and len(quant.get_attr_value(k)) == 1 for k in ["min", "max"]):
30+
min_val = quant.get_attr_value("min")[0]
31+
max_val = quant.get_attr_value("max")[0]
32+
if min_val == 0.0 and 5.999 <= max_val <= 6.0:
33+
needed_relu = "TFL_RELU6"
34+
elif min_val == 0.0:
35+
# This may introduce unneeded relu ops but will be correct.
36+
# If the --dequantize feature is used a lot in the future we can optimize this.
37+
needed_relu = "TFL_RELU"
38+
if inp_node.type == needed_relu:
39+
# If it's really obviously unneeded, we skip it.
40+
needed_relu = None
41+
elif "TFL_" + inp_node.get_attr_value("fused_activation_function", b'').decode() == needed_relu:
42+
needed_relu = None
43+
44+
if needed_relu is not None:
45+
relu_name = inp_node.name + "_relu"
46+
47+
relu6 = g.make_node(needed_relu, [quant.input[0]], op_name_scope=relu_name,
48+
skip_conversion=False, shapes=quant.output_shapes, dtypes=quant.output_dtypes)
49+
g.replace_all_inputs(dequant.output[0], relu6.output[0])
50+
else:
51+
g.replace_all_inputs(dequant.output[0], quant.input[0])
52+
53+
g.remove_node(dequant.name)
54+
if len(g.find_output_consumers(quant.output[0])) == 0:
55+
g.remove_node(quant.name)
56+
57+
return ops

0 commit comments

Comments
 (0)