Skip to content

Commit f1617ac

Browse files
Implement DivNoNan conversion (#1499)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 9c7b56c commit f1617ac

File tree

6 files changed

+47
-18
lines changed

6 files changed

+47
-18
lines changed

tests/backend_test_base.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from __future__ import unicode_literals
1010

1111
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,import-outside-toplevel
12-
# pylint: disable=wrong-import-position
12+
# pylint: disable=wrong-import-position,invalid-unary-operand-type
1313

1414
import logging
1515
import os
@@ -106,7 +106,8 @@ def run_backend(self, g, outputs, input_dict, large_model=False, postfix=""):
106106
raise ValueError("unknown backend")
107107
return y
108108

109-
def assert_results_equal(self, expected, actual, rtol, atol, check_value=True, check_shape=True, check_dtype=True):
109+
def assert_results_equal(self, expected, actual, rtol, atol, mtol=None,
110+
check_value=True, check_shape=True, check_dtype=True):
110111
for expected_val, actual_val in zip(expected, actual):
111112
if check_value:
112113
if expected_val.dtype == np.object:
@@ -115,6 +116,11 @@ def assert_results_equal(self, expected, actual, rtol, atol, check_value=True, c
115116
expected_val_str = decode(expected_val)
116117
self.assertAllEqual(expected_val_str, actual_val)
117118
else:
119+
if mtol is not None:
120+
expected_val = np.minimum(expected_val, mtol)
121+
expected_val = np.maximum(expected_val, -mtol)
122+
actual_val = np.minimum(actual_val, mtol)
123+
actual_val = np.maximum(actual_val, -mtol)
118124
self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=atol)
119125
if check_dtype:
120126
self.assertEqual(expected_val.dtype, actual_val.dtype)
@@ -285,10 +291,10 @@ def get_shape(info):
285291
self.assertEqual(onnx_shape, tf2onnx_shape)
286292
self.assertEqual(info.type.tensor_type.elem_type, graph.get_dtype(info.name))
287293

288-
def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
289-
convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
290-
check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
291-
large_model=False, premade_placeholders=False):
294+
def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
295+
rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
296+
check_value=True, check_shape=True, check_dtype=True, process_args=None, onnx_feed_dict=None,
297+
graph_validator=None, as_session=False, large_model=False, premade_placeholders=False):
292298
test_tf = not self.config.skip_tf_tests
293299
test_tflite = not self.config.skip_tflite_tests
294300
run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
@@ -330,19 +336,19 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
330336
g = optimizer.optimize_graph(g, catch_errors=False)
331337
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
332338

333-
self.assert_results_equal(expected, actual, rtol, atol, check_value, check_shape, check_dtype)
339+
self.assert_results_equal(expected, actual, rtol, atol, mtol, check_value, check_shape, check_dtype)
334340
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
335341

336342
if graph_validator:
337343
self.assertTrue(graph_validator(g))
338344

339345
if test_tflite:
340-
tfl_results, tfl_outputs = self.run_tflite(tflite_path, feed_dict)
341-
test_tflite = tfl_results is not None
346+
tfl_res, tfl_outputs = self.run_tflite(tflite_path, feed_dict)
347+
test_tflite = tfl_res is not None
342348

343349
if test_tflite:
344350
if run_tfl_consistency_test:
345-
self.assert_results_equal(expected, tfl_results, rtol, atol, check_value, check_shape, check_dtype)
351+
self.assert_results_equal(expected, tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
346352

347353
tfl_process_args = process_args.copy()
348354
if 'inputs_as_nchw' in tfl_process_args:
@@ -358,9 +364,9 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
358364
**tfl_process_args)
359365
g = optimizer.optimize_graph(g)
360366
onnx_feed_dict_without_port = {k.split(':')[0]: v for k, v in onnx_feed_dict.items()}
361-
onnx_from_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
367+
onnx_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")
362368

363-
self.assert_results_equal(tfl_results, onnx_from_tfl_res, rtol, atol, check_value, check_shape, check_dtype)
369+
self.assert_results_equal(tfl_res, onnx_tfl_res, rtol, atol, mtol, check_value, check_shape, check_dtype)
364370
self.assert_shapes_correct(g, self.config.allow_missing_shapes, not self.config.skip_onnx_checker)
365371

366372
if graph_validator:

tests/test_backend.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,17 @@ def func(x):
983983
return tf.identity(x_, name=_TFOUTPUT)
984984
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
985985

986+
@check_tf_min_version("1.14")
987+
@check_opset_min_version(11, "float equality")
988+
def test_div_no_nan(self):
989+
x_val = np.array([1.0, 2.0, -3.0, -4.0, 5.0, 0.0, float("nan"), float("-inf"), float("inf")], dtype=np.float32)
990+
y_val = np.array([1.0, 0.5, 0.0, -4.0, 0.0, 0.0, 0.0, 2.0, 0.0], dtype=np.float32)
991+
def func(x, y):
992+
x_ = tf.math.divide_no_nan(x, y)
993+
return tf.identity(x_, name=_TFOUTPUT)
994+
# TFLite expresses infinity as a value > 1e38
995+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val}, mtol=1e38)
996+
986997
@check_onnxruntime_incompatibility("Exp")
987998
def test_exp(self):
988999
x_val = np.array([1.0, 2.0, -3.0, -4.0], dtype=np.float32).reshape((2, 2))

tf2onnx/onnx_opset/math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,18 @@ def version_7(cls, ctx, node, **kwargs):
321321
pass
322322

323323

324+
@tf_op("DivNoNan")
325+
class DivNoNan:
326+
@classmethod
327+
def version_9(cls, ctx, node, **kwargs):
328+
node.type = "Div"
329+
np_dtype = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))
330+
zero_const = ctx.make_const(utils.make_name("const_zero"), np.array(0, np_dtype)).output[0]
331+
is_zero = ctx.make_node("Equal", [node.input[1], zero_const]).output[0]
332+
where_node = ctx.make_node("Where", [is_zero, zero_const, node.output[0]])
333+
ctx.insert_node_on_output(where_node, node.output[0])
334+
335+
324336
@tf_op("LRN")
325337
class LRN:
326338
@classmethod

tf2onnx/tflite_rewriters/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44

55
from tf2onnx.tflite_rewriters.tfl_scan_output_rewriter import rewrite_tfl_scan_outputs
66
from tf2onnx.tflite_rewriters.tfl_qdq_rewriter import rewrite_tfl_qdq
7-
from tf2onnx.tflite_rewriters.tfl_select_zero_mul_rewriter import rewrite_tfl_select_zero_mul
7+
from tf2onnx.tflite_rewriters.tfl_select_zero_rewriter import rewrite_tfl_select_zero
88

99
__all__ = [
1010
"rewrite_tfl_scan_outputs",
1111
"rewrite_tfl_qdq",
12-
"rewrite_tfl_select_zero_mul",
12+
"rewrite_tfl_select_zero",
1313
]

tf2onnx/tflite_rewriters/tfl_select_zero_mul_rewriter.py renamed to tf2onnx/tflite_rewriters/tfl_select_zero_rewriter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,22 @@
22

33

44
"""
5-
tf2onnx.tflite_rewriters.tfl_select_zero_mul_rewriter - TFLite has a pattern to remove NaN when multiplying by 0
5+
tf2onnx.tflite_rewriters.tfl_select_zero_rewriter - TFLite has a pattern to remove NaN when multiplying/dividing by 0
66
"""
77
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
88

99

1010
# pylint: disable=missing-docstring,unused-argument
1111

12-
def rewrite_tfl_select_zero_mul(g, ops):
12+
def rewrite_tfl_select_zero(g, ops):
1313
pattern0 = \
1414
OpTypePattern('TFL_SELECT_V2', name='select', inputs=[
1515
OpTypePattern('TFL_EQUAL', name='equal', inputs=[
1616
OpTypePattern('Const|ConstV2', name='const_eq'),
1717
OpTypePattern('*', name='term_eq'),
1818
], allow_reorder=True),
1919
OpTypePattern('Const|ConstV2', name='const_select'),
20-
OpTypePattern('TFL_MUL', name='mul', inputs=[
20+
OpTypePattern('TFL_MUL|TFL_DIV', name='mul', inputs=[
2121
OpTypePattern('*', name='term_mul1'),
2222
OpTypePattern('*', name='term_mul2'),
2323
]),

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def process_parsed_graph(g, custom_op_handlers, inputs_as_nchw, continue_on_erro
544544
if dequantize:
545545
tfl_rewriters.append(rewrite_tfl_qdq)
546546
tfl_rewriters.append(rewrite_tfl_scan_outputs)
547-
tfl_rewriters.append(rewrite_tfl_select_zero_mul)
547+
tfl_rewriters.append(rewrite_tfl_select_zero)
548548
run_rewriters(g, tfl_rewriters, continue_on_error)
549549
tfl_ops_mapping = handler.tfl_op.create_tfl_to_tf_mapping()
550550
_, _, exceptions = tensorflow_onnx_mapping(g, tfl_ops_mapping, is_tflite=True, dequantize=False)

0 commit comments

Comments
 (0)