Skip to content

Commit be29f5c

Browse files
committed
Fix pylint errors and add tensorflow version dependencies
1 parent 9f19751 commit be29f5c

File tree

2 files changed

+38
-30
lines changed

2 files changed

+38
-30
lines changed

tests/test_backend.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,7 +1916,8 @@ def graph_validator(g):
19161916
return True
19171917

19181918
self._run_test_case(func_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator)
1919-
1919+
1920+
@check_tf_min_version("1.15", "not supported in tf-2.0")
19201921
@check_opset_min_version(10, "quantize_and_dequantize")
19211922
def test_qdq_unsigned_input(self):
19221923
x_shape = [3, 3, 2]
@@ -1925,7 +1926,8 @@ def func(x):
19251926
x_ = quantize_and_dequantize(x, 1.0, 6.0, signed_input=False, narrow_range=False, range_given=True)
19261927
return tf.identity(x_, name=_TFOUTPUT)
19271928
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1928-
1929+
1930+
@check_tf_min_version("1.15", "not supported in tf-2.0")
19291931
@check_opset_min_version(10, "quantize_and_dequantize")
19301932
def test_qdq_signed_input(self):
19311933
x_shape = [3, 3, 2]
@@ -1934,7 +1936,7 @@ def func(x):
19341936
x_ = quantize_and_dequantize(x, -6.0, 6.0, signed_input=True, narrow_range=True, range_given=True)
19351937
return tf.identity(x_, name=_TFOUTPUT)
19361938
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
1937-
1939+
19381940
@skip_caffe2_backend()
19391941
@check_opset_min_version(7, "resize_nearest_neighbor")
19401942
def test_resize_nearest_neighbor(self):

tf2onnx/rewriter/quantization_ops_rewriter.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,17 @@
55
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op
66
"""
77

8-
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9-
108
import numpy as np
11-
import struct
12-
13-
14-
from tf2onnx.handler import tf_op
9+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
1510
from tf2onnx import utils
16-
from tf2onnx import constants
1711

1812
# pylint: disable=missing-docstring
1913

2014
def extract_numpy_array(node):
2115
return np.frombuffer(node.attr["value"].t.raw_data, dtype="float32")
2216

2317
def create_qdq_nodes(g, match_results):
24-
18+
2519
for match in match_results:
2620
qdq_node = match.get_op('output')
2721
qdq_node_output_dtype = g.get_dtype(qdq_node.output[0])
@@ -30,51 +24,63 @@ def create_qdq_nodes(g, match_results):
3024
# Get the attributes of qdq node
3125
narrow_range = qdq_node.attr['narrow_range'].i
3226
signed_input = qdq_node.attr['signed_input'].i
33-
34-
min_quantized, max_quantized = [-127,127]
27+
28+
min_quantized, max_quantized = [-127, 127]
3529
if not narrow_range and signed_input:
3630
min_quantized = -128
37-
31+
3832
if not signed_input:
39-
min_quantized, max_quantized = [0,255]
40-
33+
min_quantized, max_quantized = [0, 255]
34+
4135
# Get the min and max value of the inputs to QDQ op
4236
min_value = extract_numpy_array(qdq_node.inputs[1])
4337
max_value = extract_numpy_array(qdq_node.inputs[2])
44-
38+
4539
# Calculate scales from the min and max values
4640
scale_from_min_side = min_quantized/min_value if min_quantized*min_value > 0 else max_quantized
4741
scale_from_max_side = max_quantized/max_value if max_quantized*max_value > 0 else max_quantized
48-
42+
4943
if scale_from_min_side < scale_from_max_side:
5044
scale = scale_from_min_side
5145
else:
5246
scale = scale_from_max_side
53-
47+
5448
assert scale > 0
55-
49+
5650
if signed_input:
5751
zero_point = np.int8(0)
5852
else:
5953
zero_point = np.uint8(0)
60-
54+
6155
# Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
62-
y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val = 1/scale)
56+
y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1/scale)
6357
y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
64-
quant_node = g.make_node(op_type = "QuantizeLinear", inputs=[qdq_node.input[0], y_quant_scale.output[0], y_zero_point.output[0]], shapes=[qdq_node_output_shape], dtypes=[qdq_node_output_dtype], name=utils.make_name("QuantLinearNode"))
58+
quant_node = g.make_node(op_type="QuantizeLinear",
59+
inputs=[qdq_node.input[0], y_quant_scale.output[0],
60+
y_zero_point.output[0]],
61+
shapes=[qdq_node_output_shape],
62+
dtypes=[qdq_node_output_dtype],
63+
name=utils.make_name("QuantLinearNode"))
64+
6565
g.set_shape(quant_node.output[0], qdq_node_output_shape)
66-
66+
6767
g.remove_node(qdq_node.name)
6868

69-
y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val = 1/scale)
69+
y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1/scale)
7070
y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
71-
dequant_node = g.make_node(op_type = "DequantizeLinear", inputs=[quant_node.output[0], y_dequant_scale.output[0], y_inv_zero_point.output[0]], outputs = [qdq_node.output[0]], shapes=[qdq_node_output_shape], dtypes=[qdq_node_output_dtype], name=utils.make_name("DequantLinearNode"))
71+
dequant_node = g.make_node(op_type="DequantizeLinear",
72+
inputs=[quant_node.output[0], y_dequant_scale.output[0],
73+
y_inv_zero_point.output[0]],
74+
outputs=[qdq_node.output[0]],
75+
shapes=[qdq_node_output_shape],
76+
dtypes=[qdq_node_output_dtype],
77+
name=utils.make_name("DequantLinearNode"))
7278
g.set_shape(dequant_node.output[0], qdq_node_output_shape)
73-
79+
7480
return g.get_nodes()
7581

7682
def rewrite_quantize_and_dequantize(g, ops):
77-
83+
7884
pattern_for_qdq_v2 = \
7985
OpTypePattern('QuantizeAndDequantizeV2', name='output', inputs=[
8086
OpTypePattern("*"),
@@ -88,7 +94,7 @@ def rewrite_quantize_and_dequantize(g, ops):
8894
OpTypePattern(None),
8995
OpTypePattern(None),
9096
])
91-
97+
9298
# Match all the patterns for QDQ ops
9399
patterns = [pattern_for_qdq_v3, pattern_for_qdq_v2]
94100
match_results = []
@@ -97,4 +103,4 @@ def rewrite_quantize_and_dequantize(g, ops):
97103
results = list(matcher.match_ops(ops))
98104
match_results.extend(results)
99105

100-
return create_qdq_nodes(g, match_results)
106+
return create_qdq_nodes(g, match_results)

0 commit comments

Comments
 (0)