Skip to content

Commit 5373991

Browse files
committed
update quantize operator
1 parent 8dd3aee commit 5373991

File tree

1 file changed: +46 additions, −26 deletions

tf2onnx/onnx_opset/quantize.py

Lines changed: 46 additions & 26 deletions
Original file line number · Diff line number · Diff line change
@@ -27,35 +27,55 @@
2727
# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name
2828

2929

30-
@tf_op("FakeQuantWithMinMaxArgs")
class FakeQuantWithMinMaxArgs:
    # see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fake-quant-with-min-max-args
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Convert TF FakeQuantWithMinMaxArgs to an ONNX
        QuantizeLinear -> DequantizeLinear pair.

        The TF op fake-quantizes its single input to ``num_bits`` levels
        inside the constant attribute range [min, max].  ONNX has no direct
        equivalent, so the same value snapping is reproduced by quantizing
        and immediately dequantizing the input.

        :param ctx: tf2onnx graph context (make_node / make_const / rewiring).
        :param node: the TF node being converted.
        :raises RuntimeError: if narrow_range is set, or min/max do not form
            a usable range.
        """
        # Function-scope import keeps this block self-contained; the module
        # presumably already imports numpy at the top of the file.
        import numpy as np

        amin = node.get_attr("min").f
        amax = node.get_attr("max").f
        narrow_range = node.get_attr("narrow_range").i
        num_bits = node.get_attr("num_bits").i

        if narrow_range:
            raise RuntimeError(
                "Unable to convert node FakeQuantWithMinMaxArgs with "
                "narrow_range=%r" % narrow_range)

        # Nudge [min, max] so that the range contains an exactly
        # representable zero, mirroring TF's behavior: quantization needs a
        # valid integer zero point.
        if 0 < amin < amax:
            min_adj = 0
            max_adj = amax - amin
            scale = 1.
        elif amin < amax < 0:
            min_adj = amin - amax
            max_adj = 0
            scale = 1.
        elif amin <= 0 <= amax:
            scale = (amax - amin) / (2 ** num_bits - 1)
            min_adj = scale * int(amin / scale)
            # max_adj is the nudged upper bound; it is implied by
            # scale/zero_point on the ONNX side and not passed on.
            max_adj = amax + min_adj - amin
        else:
            raise RuntimeError(
                "Unable to convert node FakeQuantWithMinMaxArgs with "
                "min=%f and max=%f" % (amin, amax))

        dtype = ctx.get_dtype(node.input[0])
        shape = ctx.get_shape(node.input[0])

        # QuantizeLinear / DequantizeLinear take scale and zero point as
        # tensor inputs (not attributes), so materialize them as constants.
        # NOTE(review): zero_point must fit in uint8; for the all-positive /
        # all-negative branches above it may fall outside [0, 255] -- confirm
        # against TF's nudging semantics.
        zero_point = int(round(-min_adj / scale))
        pb_scale = ctx.make_const(
            node.name + "_scale", np.array(scale, dtype=np.float32))
        pb_zero_point = ctx.make_const(
            node.name + "_zero_point", np.array(zero_point, dtype=np.uint8))

        # Per-tensor quantization: no axis attribute is needed.
        quant_node = ctx.make_node(
            "QuantizeLinear",
            [node.input[0], pb_scale.output[0], pb_zero_point.output[0]],
            op_name_scope=node.name)

        ctx.remove_node(node.name)

        # Dequantize restores the original dtype/shape of the fake-quant
        # output, so consumers of the TF node can be rewired to it.
        last_node = ctx.make_node(
            "DequantizeLinear",
            [quant_node.output[0], pb_scale.output[0], pb_zero_point.output[0]],
            op_name_scope=node.name,
            shapes=[shape], dtypes=[dtype])
        ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])
6080

6181

0 commit comments

Comments
 (0)