Skip to content

Commit 7798a2f

Browse files
committed
add leakyrelu rewriter and related test
1 parent ed9cb4f commit 7798a2f

File tree

4 files changed

+67
-6
lines changed

4 files changed

+67
-6
lines changed

tests/test_backend.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -721,11 +721,13 @@ def test_relu(self):
721721
@skip_caffe2_backend("fails on caffe2 with dim issue")
722722
@check_onnxruntime_incompatibility("Mul")
723723
def test_leaky_relu(self):
724-
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
725-
x = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
726-
x_ = tf.nn.leaky_relu(x)
727-
_ = tf.identity(x_, name=_TFOUTPUT)
728-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
724+
for alpha in [0.1, -0.1]:
725+
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
726+
x = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
727+
x_ = tf.nn.leaky_relu(x, alpha)
728+
_ = tf.identity(x_, name=_TFOUTPUT)
729+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
730+
tf.reset_default_graph()
729731

730732
@check_onnxruntime_incompatibility("Elu")
731733
def test_elu(self):

tf2onnx/rewriter/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"cond_rewriter",
1111
"custom_rnn_rewriter",
1212
"gru_rewriter",
13+
"leakyrelu_rewriter",
1314
"loop_rewriter",
1415
"loop_rewriter_base",
1516
"lstm_rewriter",
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewrite - rewrite tensorflow subgraph to onnx leakyrelu op
6+
"""
7+
8+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9+
10+
11+
# pylint: disable=missing-docstring
12+
13+
14+
def rewrite_leakyrelu(g, ops):
15+
if g.opset < 6:
16+
return ops
17+
18+
pattern = \
19+
OpTypePattern('Maximum', name='max', inputs=[
20+
OpTypePattern('Mul', name='mul', inputs=[
21+
OpTypePattern('Const', name='alpha'),
22+
OpTypePattern('*', name='mul_input'),
23+
]),
24+
OpTypePattern('*', name='max_input'),
25+
])
26+
27+
matcher = GraphMatcher(pattern, allow_reorder=True)
28+
match_results = list(matcher.match_ops(ops))
29+
for match in match_results:
30+
max_node = match.get_op('max')
31+
max_input_node = match.get_op('max_input')
32+
mul_node = match.get_op("mul")
33+
mul_input_node = match.get_op('mul_input')
34+
35+
max_input_edge_name = _find_edges_name_btw_nodes(max_input_node, max_node)
36+
mul_input_edge_name = _find_edges_name_btw_nodes(mul_input_node, mul_node)
37+
if max_input_edge_name == mul_input_edge_name:
38+
alpha = match.get_op("alpha").get_tensor_value()
39+
if alpha >= 1:
40+
continue
41+
leakyrelu = g.make_node("LeakyRelu", inputs=max_input_edge_name, attr={"alpha": alpha},
42+
shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])])
43+
ops.remove(max_node)
44+
ops.remove(mul_node)
45+
ops.append(leakyrelu)
46+
g.replace_all_inputs(ops, max_node.output[0], leakyrelu.output[0])
47+
48+
return ops
49+
50+
51+
def _find_edges_name_btw_nodes(sender, sinker):
52+
res = []
53+
for sinker_end in sinker.input:
54+
for sender_end in sender.output:
55+
if sinker_end == sender_end:
56+
res.append(sinker_end)
57+
return res

tf2onnx/tfonnx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2727
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
2828
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
29+
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
2930
from tf2onnx.rewriter.rnn import rewrite_bi_direction_gru
3031
from tf2onnx.rewriter.rnn import rewrite_custom_rnn_cell
3132
from tf2onnx.rewriter.rnn import rewrite_generic_loop
@@ -2475,7 +2476,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
24752476
# bi-directional re-writer should be placed after single directional re-writer
24762477
rewriters = [rewrite_transpose, rewrite_flatten,
24772478
rewrite_random_uniform, rewrite_random_uniform_fold_const,
2478-
rewrite_random_normal, rewrite_dropout,
2479+
rewrite_random_normal, rewrite_dropout, rewrite_leakyrelu,
24792480
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
24802481
rewrite_single_direction_gru, rewrite_single_direction_grublock,
24812482
rewrite_bi_direction_gru, rewrite_logical_compare_with_equal,

0 commit comments

Comments
 (0)