Skip to content

Commit 1996a56

Browse files
committed
Remove assertion and add support for TF2
1 parent 4eae48f commit 1996a56

File tree

2 files changed

+3
-4
lines changed

2 files changed

+3
-4
lines changed

tests/test_backend.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
fused_batch_norm = tf.compat.v1.nn.fused_batch_norm
6363
dropout = tf.compat.v1.nn.dropout
6464
resize_nearest_neighbor = tf.compat.v1.image.resize_nearest_neighbor
65+
quantize_and_dequantize = tf.quantization.quantize_and_dequantize
6566
resize_bilinear = tf.compat.v1.image.resize_bilinear
6667
is_nan = tf.math.is_nan
6768
is_inf = tf.math.is_inf
@@ -1918,18 +1919,16 @@ def graph_validator(g):
19181919
self._run_test_case(func_fusedbn, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05, graph_validator=graph_validator)
19191920

19201921
@check_tf_min_version("1.15")
1921-
@skip_tf2()
19221922
@check_opset_min_version(10, "quantize_and_dequantize")
19231923
def test_qdq_unsigned_input(self):
19241924
x_shape = [3, 3, 2]
19251925
x_val = np.arange(1, 1+np.prod(x_shape)).astype("float32").reshape(x_shape)
19261926
def func(x):
1927-
x_ = quantize_and_dequantize(x, 1.0, 6.0, signed_input=False, narrow_range=False, range_given=True)
1927+
x_ = quantize_and_dequantize(x, 1.0, 6.0, signed_input=False, range_given=True)
19281928
return tf.identity(x_, name=_TFOUTPUT)
19291929
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
19301930

19311931
@check_tf_min_version("1.15")
1932-
@skip_tf2()
19331932
@check_opset_min_version(10, "quantize_and_dequantize")
19341933
def test_qdq_signed_input(self):
19351934
x_shape = [3, 3, 2]

tf2onnx/rewriter/quantization_ops_rewriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def create_qdq_nodes(g, match_results):
4545
else:
4646
scale = scale_from_max_side
4747

48-
assert scale > 0
48+
utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
4949

5050
if signed_input:
5151
zero_point = np.int8(0)

0 commit comments

Comments
 (0)