Skip to content

Commit 5e98390

Browse files
authored
Merge pull request #458 from mindest/bug_fix_leaky_relu
fix bug in leaky_relu for tf_version <= 1.5
2 parents 3a9f01f + 56595ec commit 5e98390

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

tests/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tf2onnx import constants, logging, utils
1515

1616
__all__ = ["TestConfig", "get_test_config", "unittest_main", "check_onnxruntime_backend",
17-
"check_tf_min_version", "skip_tf_versions", "check_onnxruntime_min_version",
17+
"check_tf_min_version", "check_tf_max_version", "skip_tf_versions", "check_onnxruntime_min_version",
1818
"check_opset_min_version", "check_target", "skip_caffe2_backend", "skip_onnxruntime_backend",
1919
"skip_opset", "check_onnxruntime_incompatibility", "validate_const_node",
2020
"group_nodes_by_type", "test_ms_domain", "check_node_domain"]

tests/test_backend.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -799,8 +799,10 @@ def test_relu(self):
799799

800800
@skip_caffe2_backend("fails on caffe2 with dim issue")
801801
@check_onnxruntime_incompatibility("Mul")
802-
def test_leaky_relu(self):
803-
x_types = [np.float32, np.int32, np.int64]
802+
@check_tf_min_version("1.6")
803+
def test_leaky_relu_int(self):
804+
# starting from tf 1.6, leaky_relu supports `feature` x of int type
805+
x_types = [np.int32, np.int64]
804806
for x_type in x_types:
805807
x_val = 1000 * np.random.random_sample([1000, 100]).astype(x_type)
806808
for alpha in [0.1, -0.1, 1.0, -1.0]:
@@ -810,6 +812,17 @@ def test_leaky_relu(self):
810812
self._run_test_case([_OUTPUT], {_INPUT: x_val})
811813
tf.reset_default_graph()
812814

815+
@skip_caffe2_backend("fails on caffe2 with dim issue")
816+
@check_onnxruntime_incompatibility("Mul")
817+
def test_leaky_relu_float(self):
818+
x_val = 1000 * np.random.random_sample([1000, 100]).astype(np.float32)
819+
for alpha in [0.1, -0.1, 1.0, -1.0]:
820+
x = tf.placeholder(x_val.dtype, [None] * x_val.ndim, name=_TFINPUT)
821+
x_ = tf.nn.leaky_relu(x, alpha)
822+
_ = tf.identity(x_, name=_TFOUTPUT)
823+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
824+
tf.reset_default_graph()
825+
813826
@check_onnxruntime_incompatibility("Elu")
814827
def test_elu(self):
815828
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))

0 commit comments

Comments
 (0)