Skip to content

Commit d91dba2

Browse files
authored
Merge pull request #617 from lucienwang1009/check_shape
check shape for unittests
2 parents cbb3538 + 4f18624 commit d91dba2

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

tests/backend_test_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def run_backend(self, g, outputs, input_dict):
7575
return y
7676

7777
def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
78-
convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=False,
78+
convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
7979
check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None):
8080
# optional - passed to process_tf_graph
8181
if process_args is None:
@@ -131,6 +131,8 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
131131
self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=atol)
132132
if check_dtype:
133133
self.assertEqual(expected_val.dtype, actual_val.dtype)
134+
# why need shape checke: issue when compare [] with scalar
135+
# https://github.com/numpy/numpy/issues/11071
134136
if check_shape:
135137
self.assertEqual(expected_val.shape, actual_val.shape)
136138

tests/test_backend.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,15 +1167,14 @@ def _test_range_non_const(self, extra_opset=None):
11671167
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
11681168
tf.reset_default_graph()
11691169

1170-
# disable this case for ms domain due to onnxruntime range-1 issue
1171-
# https://github.com/Microsoft/onnxruntime/issues/730
1172-
if not (extra_opset and extra_opset.domain == constants.MICROSOFT_DOMAIN):
1173-
x = tf.range(3.0, 3.0, 5)
1174-
_ = tf.identity(x, name=_TFOUTPUT)
1175-
g = self._run_test_case([_OUTPUT], {}, process_args=process_args)
1176-
self.assertTrue(extra_opset is None
1177-
or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
1178-
tf.reset_default_graph()
1170+
# disable this case due to onnxruntime loop issue
1171+
# https://github.com/microsoft/onnxruntime/issues/1272
1172+
# x = tf.range(3.0, 3.0, 5)
1173+
# _ = tf.identity(x, name=_TFOUTPUT)
1174+
# g = self._run_test_case([_OUTPUT], {}, process_args=process_args)
1175+
# self.assertTrue(extra_opset is None
1176+
# or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain))
1177+
# tf.reset_default_graph()
11791178

11801179
delta_val = np.array(1.5, dtype=np.float32)
11811180
delta = tf.placeholder(tf.float32, shape=(), name=_TFINPUT)

0 commit comments

Comments
 (0)