Skip to content

Commit b18067c

Browse files
Added support for FakeQuantWithMinMaxVars (#1301)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 2e642d3 commit b18067c

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

tests/test_backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
floormod = tf.math.floormod
7575
matrix_diag_part = tf.compat.v1.matrix_diag_part
7676
fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
77+
fake_quant_with_min_max_vars = tf.quantization.fake_quant_with_min_max_vars
7778
elif LooseVersion(tf.__version__) >= "1.13":
7879
conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
7980
conv3d_transpose = tf.compat.v1.nn.conv3d_transpose
@@ -95,6 +96,7 @@
9596
floormod = tf.floormod
9697
matrix_diag_part = tf.compat.v1.matrix_diag_part
9798
fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
99+
fake_quant_with_min_max_vars = tf.compat.v1.quantization.fake_quant_with_min_max_vars
98100
else:
99101
conv2d_backprop_input = tf.nn.conv2d_backprop_input
100102
conv3d_transpose = tf.nn.conv3d_transpose
@@ -4357,6 +4359,28 @@ def func_neg(x):
43574359
except ValueError:
43584360
pass
43594361

4362+
@check_opset_min_version(10)
4363+
@check_tf_min_version("1.14")
4364+
def test_fakequant_with_min_max_vars(self):
4365+
def func(x):
4366+
ret = fake_quant_with_min_max_vars(
4367+
x, min=-1024, max=1023, num_bits=8, narrow_range=False, name=None)
4368+
return tf.identity(ret, name=_TFOUTPUT)
4369+
4370+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024.
4371+
x_val0 = np.abs(x_val)
4372+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0}, rtol=1e-6, atol=1e-4)
4373+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
4374+
4375+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024
4376+
x_val[0, 0] = -1024
4377+
x_val[0, 1] = -1023
4378+
x_val[0, 2] = 1024
4379+
x_val[1, 0] = 1023
4380+
x_val[1, 1] = 1025
4381+
x_val[1, 2] = -1025
4382+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
4383+
43604384
@check_opset_min_version(9, "atan2")
43614385
def test_atan2(self):
43624386
# Test all possible pairs of pos, neg, zero for x and y.

tf2onnx/onnx_opset/quantize.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,22 @@
2424
# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name
2525

2626

27-
@tf_op("FakeQuantWithMinMaxArgs")
27+
@tf_op(["FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars"])
2828
class FakeQuantWithMinMaxArgs:
2929
# see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fake-quant-with-min-max-args
3030
@classmethod
3131
def version_10(cls, ctx, node, **kwargs):
3232
# hack to make up for the missing onnx pack op
33-
amin = node.get_attr("min").f
34-
amax = node.get_attr("max").f
33+
if node.type == "FakeQuantWithMinMaxVars":
34+
utils.make_sure(node.inputs[1].is_scalar(), "%s node %s requires const scalar value for min",
35+
node.type, node.name)
36+
utils.make_sure(node.inputs[2].is_scalar(), "%s node %s requires const scalar value for max",
37+
node.type, node.name)
38+
amin = node.inputs[1].get_tensor_value()
39+
amax = node.inputs[2].get_tensor_value()
40+
else:
41+
amin = node.get_attr("min").f
42+
amax = node.get_attr("max").f
3543
narrow_range = node.get_attr("narrow_range").i
3644
num_bits = node.get_attr("num_bits").i
3745

@@ -58,10 +66,10 @@ def version_10(cls, ctx, node, **kwargs):
5866
zero = np.array(-min_adj, dtype=np.uint8)
5967
make_sure(
6068
zero == -min_adj,
61-
"Cannot convert FakeQuantWithMinMaxArgs with "
69+
"Cannot convert %s node %s with "
6270
"min=%r max=%r numbits=%r because zero_scale=%r "
6371
"is outside uint8 boundary",
64-
amin, amax, num_bits, -min_adj)
72+
node.type, node.name, amin, amax, num_bits, -min_adj)
6573
zero_point = ctx.make_const(
6674
utils.make_name("{}_zpy".format(node.name)), zero)
6775

0 commit comments

Comments
 (0)