Skip to content

Commit 38b1a6a

Browse files
authored
Merge pull request #1009 from xadupre/i500qu
Fixes #500, replaces operator FakeQuantWithMinMaxVars
2 parents 0b0c5ab + d88ff8e commit 38b1a6a

File tree

3 files changed

+120
-1
lines changed

3 files changed

+120
-1
lines changed

tests/test_backend.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
is_inf = tf.math.is_inf
7171
floormod = tf.math.floormod
7272
matrix_diag_part = tf.compat.v1.matrix_diag_part
73+
fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
7374
elif LooseVersion(tf.__version__) >= "1.13":
7475
conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
7576
multinomial = tf.compat.v1.random.multinomial
@@ -89,6 +90,7 @@
8990
is_inf = tf.math.is_inf
9091
floormod = tf.floormod
9192
matrix_diag_part = tf.compat.v1.matrix_diag_part
93+
fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
9294
else:
9395
conv2d_backprop_input = tf.nn.conv2d_backprop_input
9496
multinomial = tf.multinomial
@@ -3353,6 +3355,42 @@ def func(base_matrix, diag, k):
33533355

33543356
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val, _INPUT1: diag_val, _INPUT2: k_val})
33553357

3358+
@check_opset_min_version(10)
3359+
@check_tf_min_version("1.14")
3360+
def test_fakequant_with_min_max(self):
3361+
def func(x):
3362+
ret = fake_quant_with_min_max_args(
3363+
x, min=-1024, max=1023, num_bits=8, narrow_range=False, name=None)
3364+
return tf.identity(ret, name=_TFOUTPUT)
3365+
3366+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024.
3367+
x_val0 = np.abs(x_val)
3368+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val0}, rtol=1e-6, atol=1e-4)
3369+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
3370+
3371+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024
3372+
x_val[0, 0] = -1024
3373+
x_val[0, 1] = -1023
3374+
x_val[0, 2] = 1024
3375+
x_val[1, 0] = 1023
3376+
x_val[1, 1] = 1025
3377+
x_val[1, 2] = -1025
3378+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
3379+
3380+
@check_opset_min_version(10)
3381+
@check_tf_min_version("1.14")
3382+
def test_fakequant_with_min_max_same_sign(self):
3383+
def func_neg(x):
3384+
ret = fake_quant_with_min_max_args(
3385+
x, min=-1024*3, max=-1024, num_bits=8, narrow_range=False, name=None)
3386+
return tf.identity(ret, name=_TFOUTPUT)
3387+
3388+
x_val = np.random.random(size=[4, 3]).astype(np.float32) * 2048. - 1024 * 3.
3389+
try:
3390+
self._run_test_case(func_neg, [_OUTPUT], {_INPUT: x_val}, rtol=1e-6, atol=1e-4)
3391+
except ValueError:
3392+
pass
3393+
33563394
@check_opset_min_version(9, "atan2")
33573395
def test_atan2(self):
33583396
# Test all possible pairs of pos, neg, zero for x and y.

tf2onnx/onnx_opset/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# Licensed under the MIT license.
33
"""tf2onnx.onnx_opset module"""
44

5-
from . import common, controlflow, generator, logical, math, misc, nn, reduction, rnn, tensor, traditionalml
5+
from . import common, controlflow, generator, logical, math, misc, nn, quantize, reduction, rnn, tensor, traditionalml

tf2onnx/onnx_opset/quantize.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tensor
6+
"""
7+
8+
from __future__ import division
9+
from __future__ import print_function
10+
from __future__ import unicode_literals
11+
12+
import logging
13+
14+
import numpy as np
15+
from onnx.onnx_pb import TensorProto
16+
17+
from tf2onnx import utils
18+
from tf2onnx.handler import tf_op
19+
from tf2onnx.utils import make_sure
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name
25+
26+
27+
@tf_op("FakeQuantWithMinMaxArgs")
28+
class FakeQuantWithMinMaxArgs:
29+
# see https://www.tensorflow.org/api_docs/cc/class/tensorflow/ops/fake-quant-with-min-max-args
30+
@classmethod
31+
def version_10(cls, ctx, node, **kwargs):
32+
# hack to make up for the missing onnx pack op
33+
amin = node.get_attr("min").f
34+
amax = node.get_attr("max").f
35+
narrow_range = node.get_attr("narrow_range").i
36+
num_bits = node.get_attr("num_bits").i
37+
38+
make_sure(
39+
not narrow_range,
40+
"Unable to convert node FakeQuantWithMinMaxArgs with narrow_range=%r",
41+
narrow_range)
42+
make_sure(
43+
num_bits == 8,
44+
"Unable to convert node FakeQuantWithMinMaxArgs with "
45+
"num_bits=%r", num_bits)
46+
47+
scale = (amax - amin) / (2 ** num_bits - 1)
48+
min_adj = np.around(amin / scale)
49+
50+
dtype = ctx.get_dtype(node.input[0])
51+
shape = ctx.get_shape(node.input[0])
52+
axis = 1
53+
idtype = TensorProto.UINT8
54+
55+
pb_scale = ctx.make_const(
56+
utils.make_name("{}_scaley".format(node.name)),
57+
np.array(scale, dtype=np.float32))
58+
zero = np.array(-min_adj, dtype=np.uint8)
59+
make_sure(
60+
zero == -min_adj,
61+
"Cannot convert FakeQuantWithMinMaxArgs with "
62+
"min=%r max=%r numbits=%r because zero_scale=%r "
63+
"is outside uint8 boundary",
64+
amin, amax, num_bits, -min_adj)
65+
zero_point = ctx.make_const(
66+
utils.make_name("{}_zpy".format(node.name)), zero)
67+
68+
new_node = ctx.make_node(
69+
"QuantizeLinear", [node.input[0], pb_scale.name, zero_point.name],
70+
op_name_scope=node.name, attr={"axis": axis},
71+
shapes=[shape], dtypes=[idtype])
72+
output_name = new_node.output[0]
73+
node.input[0] = output_name
74+
75+
ctx.remove_node(node.name)
76+
77+
last_node = ctx.make_node(
78+
"DequantizeLinear", [new_node.output[0], pb_scale.name, zero_point.name],
79+
op_name_scope=node.name, attr={"axis": axis},
80+
shapes=[shape], dtypes=[dtype])
81+
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], last_node.output[0])

0 commit comments

Comments
 (0)