Skip to content

Commit 8dd3aee

Browse files
committed
update PR
1 parent eca8647 commit 8dd3aee

File tree

4 files changed

+77
-171
lines changed

4 files changed

+77
-171
lines changed

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
is_inf = tf.math.is_inf
7070
floormod = tf.math.floormod
7171
matrix_diag_part = tf.compat.v1.matrix_diag_part
72+
fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
7273
elif LooseVersion(tf.__version__) >= "1.13":
7374
conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
7475
multinomial = tf.compat.v1.random.multinomial
@@ -88,6 +89,7 @@
8889
is_inf = tf.math.is_inf
8990
floormod = tf.floormod
9091
matrix_diag_part = tf.compat.v1.matrix_diag_part
92+
fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
9193
else:
9294
conv2d_backprop_input = tf.nn.conv2d_backprop_input
9395
multinomial = tf.multinomial
@@ -3352,6 +3354,19 @@ def func(base_matrix, diag, k):
33523354

33533355
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val, _INPUT1: diag_val, _INPUT2: k_val})
33543356

3357+
@check_opset_min_version(10)
3358+
@check_tf_min_version("1.14")
3359+
def test_fakequant_with_min_max(self):
3360+
x_val = np.random.random(size=[4, 5]).astype(np.float32) * 2048. - 1024.
3361+
def func(x):
3362+
ret = fake_quant_with_min_max_args(
3363+
x, min=-1024, max=1024, num_bits=8, narrow_range=False, name=None)
3364+
return tf.identity(ret, name=_TFOUTPUT)
3365+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3366+
33553367

33563368
if __name__ == '__main__':
3369+
cl = BackendTests()
3370+
cl.setUp()
3371+
cl.test_fakequant_with_min_max()
33573372
unittest_main()

tests/test_quantization.py

Lines changed: 0 additions & 170 deletions
This file was deleted.

tf2onnx/onnx_opset/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# Licensed under the MIT license.
33
"""tf2onnx.onnx_opset module"""
44

5-
from . import common, controlflow, generator, logical, math, misc, nn, reduction, rnn, tensor, traditionalml
5+
from . import common, controlflow, generator, logical, math, misc, nn, quantize, reduction, rnn, tensor, traditionalml

tf2onnx/onnx_opset/quantize.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tensor
6+
"""
7+
8+
from __future__ import division
9+
from __future__ import print_function
10+
from __future__ import unicode_literals
11+
12+
import logging
13+
import sys
14+
15+
import numpy as np
16+
from onnx import onnx_pb
17+
from onnx.onnx_pb import TensorProto
18+
19+
from tf2onnx import constants, utils
20+
from tf2onnx.graph_builder import GraphBuilder
21+
from tf2onnx.handler import tf_op
22+
from tf2onnx.onnx_opset import nn, math
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
# pylint: disable=unused-argument,missing-docstring,unused-variable,pointless-string-statement,invalid-name
28+
29+
30+
@tf_op("FakeQuantWithMinMaxVars")
31+
class FakeQuantWithMinMaxVars:
32+
@classmethod
33+
def version_11(cls, ctx, node, **kwargs):
34+
# hack to make up for the missing onnx pack op
35+
import pprint
36+
pprint.pprint(node)
37+
amin = node.get_attr("min").i
38+
if axis < 0:
39+
axis += len(ctx.get_shape(node.input[0])) + 1
40+
41+
inputs = []
42+
dtype = None
43+
# insert Unsqueeze on each input
44+
for i, n in enumerate(node.inputs):
45+
dtype = ctx.get_dtype(node.input[i])
46+
shape = ctx.get_shape(node.input[i])
47+
new_node = ctx.make_node("Unsqueeze", [node.input[i]], op_name_scope=node.name, attr={"axes": [axis]},
48+
shapes=[shape], dtypes=[dtype])
49+
output_name = new_node.output[0]
50+
node.input[i] = output_name
51+
inputs.append(output_name)
52+
53+
shapes = node.output_shapes
54+
dtypes = node.output_dtypes
55+
ctx.remove_node(node.name)
56+
# concat all unqueezes
57+
concat = ctx.make_node("Concat", inputs, op_name_scope=node.name, attr={"axis": axis},
58+
shapes=shapes, dtypes=dtypes)
59+
ctx.replace_all_inputs(ctx.get_nodes(), node.output[0], concat.output[0])
60+
61+

0 commit comments

Comments
 (0)