
Commit e3bc51c

fix operator for fakequantize
1 parent 5373991 commit e3bc51c

2 files changed: +27 -11 lines changed


tests/test_backend.py

Lines changed: 4 additions & 4 deletions
@@ -3357,7 +3357,7 @@ def func(base_matrix, diag, k):
     @check_opset_min_version(10)
     @check_tf_min_version("1.14")
     def test_fakequant_with_min_max(self):
-        x_val = np.random.random(size=[4, 5]).astype(np.float32) * 2048. - 1024.
+        x_val = np.random.random(size=[3, 3]).astype(np.float32) * 2048. - 1024.
         def func(x):
             ret = fake_quant_with_min_max_args(
                 x, min=-1024, max=1024, num_bits=8, narrow_range=False, name=None)
@@ -3366,7 +3366,7 @@ def func(x):
 
 
 if __name__ == '__main__':
-    cl = BackendTests()
-    cl.setUp()
-    cl.test_fakequant_with_min_max()
+    # cl = BackendTests()
+    # cl.setUp()
+    # cl.test_fakequant_with_min_max()
     unittest_main()
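
For reference, the fake-quant round trip this test exercises can be approximated in plain numpy as below. This is only a sketch of the usual uint8 affine scheme (TensorFlow additionally nudges min/max so that zero is exactly representable), and fake_quant_reference is a hypothetical helper, not part of the test suite.

import numpy as np

# Rough numpy approximation of fake_quant_with_min_max_args for
# min=-1024, max=1024, num_bits=8, narrow_range=False.
def fake_quant_reference(x, amin=-1024., amax=1024., num_bits=8):
    scale = (amax - amin) / (2 ** num_bits - 1)
    q = np.round((np.clip(x, amin, amax) - amin) / scale)  # quantize to [0, 255]
    return (q * scale + amin).astype(np.float32)           # dequantize back to float

x_val = np.random.random(size=[3, 3]).astype(np.float32) * 2048. - 1024.
y_ref = fake_quant_reference(x_val)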

tf2onnx/onnx_opset/quantize.py

Lines changed: 23 additions & 7 deletions
@@ -13,7 +13,7 @@
 import sys
 
 import numpy as np
-from onnx import onnx_pb
+from onnx import onnx_pb, numpy_helper, TensorProto
 from onnx.onnx_pb import TensorProto
 
 from tf2onnx import constants, utils
@@ -37,12 +37,19 @@ def version_11(cls, ctx, node, **kwargs):
         amax = node.get_attr("max").f
         narrow_range = node.get_attr("narrow_range").i
         num_bits = node.get_attr("num_bits").i
-
+
         if narrow_range:
             raise RuntimeError(
                 "Unable to convert node FakeQuantWithMinMaxArgs with "
                 "narrow_range=%r" % narrow_range)
-
+        if num_bits != 8:
+            raise RuntimeError(
+                "Unable to convert node FakeQuantWithMinMaxArgs with "
+                "num_bits=%r" % num_bits)
+
+        scale = (amax - amin) / (2 ** num_bits - 1)
+        min_adj = scale * int(amin / scale)
+        max_adj = amax + min_adj - amin
         if 0 < amin < amax:
             min_adj = 0
             max_adj = amax - amin
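
Taken on their own, the scale/min_adj/max_adj formulas introduced in this hunk give the following values for the test's [-1024, 1024] range. fakequant_params is just a hypothetical wrapper used for illustration, not a tf2onnx function.

def fakequant_params(amin, amax, num_bits=8):
    # Same formulas as the hunk above, wrapped for illustration only.
    scale = (amax - amin) / (2 ** num_bits - 1)
    min_adj = scale * int(amin / scale)  # snap amin onto the quantization grid
    max_adj = amax + min_adj - amin      # preserve the width of the range
    if 0 < amin < amax:                  # a purely positive range is anchored at zero
        min_adj = 0
        max_adj = amax - amin
    return scale, min_adj, max_adj

# For the test's range: scale ~= 8.031, min_adj ~= -1019.98, max_adj ~= 1028.02
print(fakequant_params(-1024., 1024.))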
@@ -62,18 +69,27 @@ def version_11(cls, ctx, node, **kwargs):
 
         dtype = ctx.get_dtype(node.input[0])
         shape = ctx.get_shape(node.input[0])
+        axis = 1
+        idtype = TensorProto.UINT8
+
+        pb_scale = ctx.make_const(
+            utils.make_name("{}_scaley".format(node.name)),
+            np.array(scale, dtype=np.float32))
+        zero_point = ctx.make_const(
+            utils.make_name("{}_zpy".format(node.name)),
+            np.array(min_adj, dtype=np.uint8))
 
         new_node = ctx.make_node(
-            "QuantizeLinear", [node.input[0], pb_scale, y_zero_point],
-            op_name_scope=node.name, attr={"axes": [axis]},
+            "QuantizeLinear", [node.input[0], pb_scale.name, zero_point.name],
+            op_name_scope=node.name, attr={"axis": axis},
             shapes=[shape], dtypes=[idtype])
         output_name = new_node.output[0]
-        node.input[i] = output_name
+        node.input[0] = output_name
 
         ctx.remove_node(node.name)
 
         last_node = ctx.make_node(
-            "DequantizeLinear", [new_node.output[0], x_scale, x_zero_point],
+            "DequantizeLinear", [new_node.output[0], pb_scale.name, zero_point.name],
             op_name_scope=node.name, attr={"axis": axis},
             shapes=[shape], dtypes=[dtype])
         ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])
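
The net effect of the rewrite is that each FakeQuantWithMinMaxArgs node becomes a QuantizeLinear/DequantizeLinear pair sharing one scale constant and one uint8 zero-point constant. A self-contained sketch of that pattern built with onnx.helper follows; the tensor names and constant values are illustrative, not what the converter emits.

import numpy as np
from onnx import helper, numpy_helper, TensorProto

# Standalone sketch of the QuantizeLinear -> DequantizeLinear pattern that
# replaces FakeQuantWithMinMaxArgs. Names and constants are illustrative.
scale_init = numpy_helper.from_array(
    np.array(2048.0 / 255.0, dtype=np.float32), "fq_scale")
zp_init = numpy_helper.from_array(
    np.array(127, dtype=np.uint8), "fq_zero_point")

quant = helper.make_node(
    "QuantizeLinear", ["x", "fq_scale", "fq_zero_point"], ["x_q"])
dequant = helper.make_node(
    "DequantizeLinear", ["x_q", "fq_scale", "fq_zero_point"], ["y"])

graph = helper.make_graph(
    [quant, dequant], "fakequant_pattern",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 3])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 3])],
    initializer=[scale_init, zp_init])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])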
