Skip to content

Commit bcde434

Browse files
committed
fix fake_quant
1 parent 4ba1407 commit bcde434

File tree

2 files changed

+36
-31
lines changed

2 files changed

+36
-31
lines changed

tests/test_backend.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3359,23 +3359,37 @@ def func(base_matrix, diag, k):
33593359
def test_fakequant_with_min_max(self):
33603360
def func(x):
33613361
ret = fake_quant_with_min_max_args(
3362-
x, min=-1024, max=1024, num_bits=8, narrow_range=False, name=None)
3362+
x, min=-1024, max=1023, num_bits=8, narrow_range=False, name=None)
33633363
return tf.identity(ret, name=_TFOUTPUT)
33643364

33653365
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024.
33663366
x_val0 = np.abs(x_val)
3367-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0})
3368-
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3367+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0}, rtol=1e-6, atol=1e-4)
3368+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
3369+
3370+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024
3371+
x_val[0, 0] = -1024
3372+
x_val[0, 1] = -1023
3373+
x_val[0, 2] = 1024
3374+
x_val[1, 0] = 1023
3375+
x_val[1, 1] = 1025
3376+
x_val[1, 2] = -1025
3377+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
3378+
3379+
@check_opset_min_version(10)
3380+
@check_tf_min_version("1.14")
3381+
def test_fakequant_with_min_max_same_sign(self):
3382+
def func_neg(x):
3383+
ret = fake_quant_with_min_max_args(
3384+
x, min=-1024*3, max=-1024, num_bits=8, narrow_range=False, name=None)
3385+
return tf.identity(ret, name=_TFOUTPUT)
3386+
3387+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024 * 3.
3388+
try:
3389+
self._run_test_case(func_neg, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
3390+
except RuntimeError:
3391+
pass
33693392

33703393

33713394
if __name__ == '__main__':
3372-
#cl = BackendTests()
3373-
#cl.setUp()
3374-
#cl.test_fakequant_with_min_max()
3375-
#import cProfile
3376-
#cProfile.run('unittest_main()', 'restats')
33773395
unittest_main()
3378-
#import pstats
3379-
#from pstats import SortKey
3380-
#p = pstats.Stats('restats')
3381-
#p.sort_stats(SortKey.CUMULATIVE).print_stats()

tf2onnx/onnx_opset/quantize.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,8 @@ def version_11(cls, ctx, node, **kwargs):
4444
"Unable to convert node FakeQuantWithMinMaxArgs with "
4545
"num_bits=%r" % num_bits)
4646

47-
if 0 < amin < amax:
48-
min_adj = 0
49-
max_adj = amax - amin
50-
scale = 1.
51-
elif amin < amax < 0:
52-
min_adj = amin - amax
53-
max_adj = 0
54-
scale = 1.
55-
elif amin <= 0 <= amax:
56-
scale = (amax - amin) / (2 ** num_bits - 1)
57-
min_adj = scale * int(amin / scale)
58-
max_adj = amax + min_adj - amin
59-
else:
60-
raise RuntimeError(
61-
"Unable to convert node FakeQuantWithMinMaxArgs with "
62-
"min=%f and max=%f" % (amin, amax))
47+
scale = (amax - amin) / (2 ** num_bits - 1)
48+
min_adj = np.around(amin / scale)
6349

6450
dtype = ctx.get_dtype(node.input[0])
6551
shape = ctx.get_shape(node.input[0])
@@ -69,9 +55,15 @@ def version_11(cls, ctx, node, **kwargs):
6955
pb_scale = ctx.make_const(
7056
utils.make_name("{}_scaley".format(node.name)),
7157
np.array(scale, dtype=np.float32))
58+
zero = np.array(-min_adj, dtype=np.uint8)
59+
if zero != -min_adj:
60+
raise RuntimeError(
61+
"Cannot convert FakeQuantWithMinMaxArgs with "
62+
"min={} max={} numbits={} because zero_scale={} "
63+
"is outside uint8 boundary".format(
64+
amin, amax, num_bits, -min_adj))
7265
zero_point = ctx.make_const(
73-
utils.make_name("{}_zpy".format(node.name)),
74-
np.array(min_adj, dtype=np.uint8))
66+
utils.make_name("{}_zpy".format(node.name)), zero)
7567

7668
new_node = ctx.make_node(
7769
"QuantizeLinear", [node.input[0], pb_scale.name, zero_point.name],
@@ -87,4 +79,3 @@ def version_11(cls, ctx, node, **kwargs):
8779
op_name_scope=node.name, attr={"axis": axis},
8880
shapes=[shape], dtypes=[dtype])
8981
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])
90-

0 commit comments

Comments
 (0)