|
16 | 16 | from backend_test_base import Tf2OnnxBackendTestBase
|
17 | 17 | # pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
|
18 | 18 | from common import * # pylint: disable=wildcard-import,unused-wildcard-import
|
| 19 | +from tf2onnx import constants |
19 | 20 |
|
20 | 21 | # pylint: disable=missing-docstring,invalid-name,unused-argument
|
21 | 22 |
|
@@ -905,60 +906,95 @@ def test_sqrt(self):
|
905 | 906 | _ = tf.identity(x_, name=_TFOUTPUT)
|
906 | 907 | self._run_test_case([_OUTPUT], {_INPUT: x_val})
|
907 | 908 |
|
908 |
| - @check_opset_min_version(7, "cast") |
909 |
| - def test_range_const(self): |
| 909 | + def _test_range_const(self, extra_opset=None): |
| 910 | + process_args = {} |
| 911 | + if extra_opset is not None: |
| 912 | + process_args["extra_opset"] = [extra_opset] |
| 913 | + |
910 | 914 | x = tf.range(5)
|
911 | 915 | _ = tf.identity(x, name=_TFOUTPUT)
|
912 |
| - self._run_test_case([_OUTPUT], {}) |
| 916 | + self._run_test_case([_OUTPUT], {}, process_args=process_args) |
913 | 917 | tf.reset_default_graph()
|
914 | 918 |
|
915 | 919 | x = tf.range(3, 3, 5)
|
916 | 920 | _ = tf.identity(x, name=_TFOUTPUT)
|
917 |
| - self._run_test_case([_OUTPUT], {}) |
| 921 | + self._run_test_case([_OUTPUT], {}, process_args=process_args) |
918 | 922 | tf.reset_default_graph()
|
919 | 923 |
|
920 | 924 | x = tf.range(0, -5, -2)
|
921 | 925 | _ = tf.identity(x, name=_TFOUTPUT)
|
922 |
| - self._run_test_case([_OUTPUT], {}) |
| 926 | + self._run_test_case([_OUTPUT], {}, process_args=process_args) |
923 | 927 | tf.reset_default_graph()
|
924 | 928 |
|
925 | 929 | x = tf.range(-5.0, 5.0, 1.5)
|
926 | 930 | _ = tf.identity(x, name=_TFOUTPUT)
|
927 |
| - self._run_test_case([_OUTPUT], {}) |
| 931 | + self._run_test_case([_OUTPUT], {}, process_args=process_args) |
928 | 932 | tf.reset_default_graph()
|
929 | 933 |
|
930 | 934 | x = tf.range(2.5, 5.0, 10.0)
|
931 | 935 | _ = tf.identity(x, name=_TFOUTPUT)
|
932 |
| - self._run_test_case([_OUTPUT], {}) |
| 936 | + self._run_test_case([_OUTPUT], {}, process_args=process_args) |
| 937 | + |
| 938 | + def _test_range_non_const(self, extra_opset=None): |
| 939 | + process_args = {} |
| 940 | + if extra_opset is not None: |
| 941 | + process_args["extra_opset"] = [extra_opset] |
933 | 942 |
|
934 |
| - def test_range_non_const(self): |
935 | 943 | x = tf.range(5.0)
|
936 | 944 | _ = tf.identity(x, name=_TFOUTPUT)
|
937 |
| - self._run_test_case([_OUTPUT], {}) |
| 945 | + g = self._run_test_case([_OUTPUT], {}, process_args=process_args) |
| 946 | + self.assertTrue(extra_opset is None |
| 947 | + or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain)) |
938 | 948 | tf.reset_default_graph()
|
939 | 949 |
|
940 | 950 | x = tf.range(0, -5.0, -2)
|
941 | 951 | _ = tf.identity(x, name=_TFOUTPUT)
|
942 |
| - self._run_test_case([_OUTPUT], {}) |
| 952 | + g = self._run_test_case([_OUTPUT], {}, process_args=process_args) |
| 953 | + self.assertTrue(extra_opset is None |
| 954 | + or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain)) |
943 | 955 | tf.reset_default_graph()
|
944 | 956 |
|
945 |
| - x = tf.range(3.0, 3.0, 5) |
946 |
| - _ = tf.identity(x, name=_TFOUTPUT) |
947 |
| - self._run_test_case([_OUTPUT], {}) |
948 |
| - tf.reset_default_graph() |
| 957 | + # disable this case for ms domain due to onnxruntime range-1 issue |
| 958 | + # https://github.com/Microsoft/onnxruntime/issues/730 |
| 959 | + if not (extra_opset and extra_opset.domain == constants.MICROSOFT_DOMAIN): |
| 960 | + x = tf.range(3.0, 3.0, 5) |
| 961 | + _ = tf.identity(x, name=_TFOUTPUT) |
| 962 | + g = self._run_test_case([_OUTPUT], {}, process_args=process_args) |
| 963 | + self.assertTrue(extra_opset is None |
| 964 | + or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain)) |
| 965 | + tf.reset_default_graph() |
949 | 966 |
|
950 | 967 | delta_val = np.array(1.5, dtype=np.float32)
|
951 | 968 | delta = tf.placeholder(tf.float32, shape=(), name=_TFINPUT)
|
952 | 969 | x = tf.range(-5.0, 5.0, delta)
|
953 | 970 | _ = tf.identity(x, name=_TFOUTPUT)
|
954 |
| - self._run_test_case([_OUTPUT], {_INPUT: delta_val}) |
| 971 | + g = self._run_test_case([_OUTPUT], {_INPUT: delta_val}, process_args=process_args) |
| 972 | + self.assertTrue(extra_opset is None |
| 973 | + or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain)) |
955 | 974 | tf.reset_default_graph()
|
956 | 975 |
|
957 | 976 | start_val = np.array(2.5, dtype=np.float32)
|
958 | 977 | start = tf.placeholder(tf.float32, shape=(), name=_TFINPUT)
|
959 | 978 | x = tf.range(start, 5.0, 10.0)
|
960 | 979 | _ = tf.identity(x, name=_TFOUTPUT)
|
961 |
| - self._run_test_case([_OUTPUT], {_INPUT: start_val}) |
| 980 | + g = self._run_test_case([_OUTPUT], {_INPUT: start_val}, process_args=process_args) |
| 981 | + self.assertTrue(extra_opset is None |
| 982 | + or check_node_domain(group_nodes_by_type(g)["Range"][0], extra_opset.domain)) |
| 983 | + |
| 984 | + @check_opset_min_version(7, "cast") |
| 985 | + def test_range_const(self): |
| 986 | + self._test_range_const() |
| 987 | + |
| 988 | + def test_range_non_const(self): |
| 989 | + self._test_range_non_const() |
| 990 | + |
| 991 | + @test_ms_domain() |
| 992 | + def test_ms_range_const(self, extra_opset): |
| 993 | + self._test_range_const(extra_opset) |
| 994 | + |
| 995 | + @test_ms_domain() |
| 996 | + def test_ms_range_non_const(self, extra_opset): |
| 997 | + self._test_range_non_const(extra_opset) |
962 | 998 |
|
963 | 999 | @check_onnxruntime_incompatibility("Sqrt")
|
964 | 1000 | def test_rsqrt(self):
|
|
0 commit comments