Skip to content

Commit 758b876

Browse files
committed
fix bug in leaky_relu for tf_version <= 1.5
1 parent 3a9f01f commit 758b876

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tf2onnx import constants, logging, utils
1515

1616
__all__ = ["TestConfig", "get_test_config", "unittest_main", "check_onnxruntime_backend",
17-
"check_tf_min_version", "skip_tf_versions", "check_onnxruntime_min_version",
17+
"check_tf_min_version", "check_tf_max_version", "skip_tf_versions", "check_onnxruntime_min_version",
1818
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
1919
"skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
2020
"group_nodes_by_type", "test_ms_domain", "check_node_domain"]

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -799,7 +799,9 @@ def test_relu(self):
799799

800800
@skip_caffe2_backend("fails on caffe2 with dim issue")
801801
@check_onnxruntime_incompatibility("Mul")
802+
@check_tf_min_version("1.6")
802803
def test_leaky_relu(self):
804+
# starting from tf 1.6, leaky_relu supports `feature` x of int type
803805
x_types = [np.float32, np.int32, np.int64]
804806
for x_type in x_types:
805807
x_val = 1000 * np.random.random_sample([1000, 100]).astype(x_type)
@@ -810,6 +812,19 @@ def test_leaky_relu(self):
810812
self._run_test_case([_OUTPUT], {_INPUT: x_val})
811813
tf.reset_default_graph()
812814

815+
@skip_caffe2_backend("fails on caffe2 with dim issue")
816+
@check_onnxruntime_incompatibility("Mul")
817+
@check_tf_max_version("1.5")
818+
def test_leaky_relu_old(self):
819+
# for tf_version <= 1.5, leaky_relu requires `feature` x to be of type `float32`
820+
x_val = 1000 * np.random.random_sample([1000, 100]).astype(np.float32)
821+
for alpha in [0.1, -0.1, 1.0, -1.0]:
822+
x = tf.placeholder(x_val.dtype, [None] * x_val.ndim, name=_TFINPUT)
823+
x_ = tf.nn.leaky_relu(x, alpha)
824+
_ = tf.identity(x_, name=_TFOUTPUT)
825+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
826+
tf.reset_default_graph()
827+
813828
@check_onnxruntime_incompatibility("Elu")
814829
def test_elu(self):
815830
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))

0 commit comments

Comments
 (0)