Skip to content

Commit d949ceb

Browse files
authored
Merge pull request #328 from zhijxu-MS/tmp_branch_for_PR
add leakyrelu rewriter and related test
2 parents 54b5b90 + f121b10 commit d949ceb

File tree

7 files changed

+70
-11
lines changed

7 files changed

+70
-11
lines changed

tests/test_backend.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -721,11 +721,13 @@ def test_relu(self):
721721
@skip_caffe2_backend("fails on caffe2 with dim issue")
722722
@check_onnxruntime_incompatibility("Mul")
723723
def test_leaky_relu(self):
724-
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
725-
x = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
726-
x_ = tf.nn.leaky_relu(x)
727-
_ = tf.identity(x_, name=_TFOUTPUT)
728-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
724+
for alpha in [0.1, -0.1, 1.0, -1.0, 10.0, -10.0]:
725+
x_val = 1000*np.random.random_sample([1000, 100]).astype(np.float32)
726+
x = tf.placeholder(tf.float32, [None]*x_val.ndim, name=_TFINPUT)
727+
x_ = tf.nn.leaky_relu(x, alpha)
728+
_ = tf.identity(x_, name=_TFOUTPUT)
729+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
730+
tf.reset_default_graph()
729731

730732
@check_onnxruntime_incompatibility("Elu")
731733
def test_elu(self):

tf2onnx/graph.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,8 +552,7 @@ def _push_stack(stack, node, in_stack):
552552
stack.append(node)
553553
if node in in_stack:
554554
raise ValueError('Graph has cycles.')
555-
else:
556-
in_stack[node] = True
555+
in_stack[node] = True
557556

558557
def _get_unvisited_child(g, node, not_visited):
559558
for child in g[node]:

tf2onnx/rewriter/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"cond_rewriter",
1111
"custom_rnn_rewriter",
1212
"gru_rewriter",
13+
"leakyrelu_rewriter",
1314
"loop_rewriter",
1415
"loop_rewriter_base",
1516
"lstm_rewriter",
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx leakyrelu op
6+
"""
7+
8+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9+
10+
11+
# pylint: disable=missing-docstring
12+
13+
14+
def rewrite_leakyrelu(g, ops):
15+
if g.opset < 6:
16+
return ops
17+
18+
pattern = \
19+
OpTypePattern('Maximum', name='max', inputs=[
20+
OpTypePattern('Mul', name='mul', inputs=[
21+
OpTypePattern('Const', name='alpha'),
22+
OpTypePattern('*', name='mul_input'),
23+
]),
24+
OpTypePattern('*', name='max_input'),
25+
])
26+
27+
matcher = GraphMatcher(pattern, allow_reorder=True)
28+
match_results = list(matcher.match_ops(ops))
29+
for match in match_results:
30+
max_node = match.get_op('max')
31+
max_input_node = match.get_op('max_input')
32+
mul_node = match.get_op("mul")
33+
mul_input_node = match.get_op('mul_input')
34+
35+
max_input_edge_name = _find_edge_name_between_nodes(max_input_node, max_node)
36+
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
37+
if max_input_edge_name == mul_input_edge_name:
38+
alpha = match.get_op("alpha").get_tensor_value()
39+
if alpha >= 1:
40+
continue
41+
leakyrelu = g.make_node("LeakyRelu", inputs=[max_input_edge_name], attr={"alpha": alpha},
42+
shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])])
43+
ops.remove(max_node)
44+
ops.remove(mul_node)
45+
ops.append(leakyrelu)
46+
g.replace_all_inputs(ops, max_node.output[0], leakyrelu.output[0])
47+
48+
return ops
49+
50+
51+
def _find_edge_name_between_nodes(src_node, consumer_node):
52+
# find the first edge connection between two nodes.
53+
for consumer_end in consumer_node.input:
54+
for src_end in src_node.output:
55+
if consumer_end == src_end:
56+
return consumer_end
57+
return None

tf2onnx/rewriter/random_uniform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the MIT license.
33

44
"""
5-
tf2onnx.rewrite - rewrite tensorflow subgraph to onnx random_uniform op
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx random_uniform op
66
"""
77
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
88
from tf2onnx import utils

tf2onnx/shape_inference.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ def infer_shape_for_node(g, node):
9898
val = list(shape_attr.floats)
9999
if val:
100100
raise ValueError("placeholder shape has floats value, and not scalar value")
101-
else:
102-
new_shape = ()
101+
new_shape = ()
103102

104103
if new_shape is not None:
105104
g.set_shape(node.output[0], new_shape)

tf2onnx/tfonnx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2727
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
2828
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
29+
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
2930
from tf2onnx.rewriter.rnn import rewrite_bi_direction_gru
3031
from tf2onnx.rewriter.rnn import rewrite_custom_rnn_cell
3132
from tf2onnx.rewriter.rnn import rewrite_generic_loop
@@ -2475,7 +2476,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
24752476
# bi-directional re-writer should be placed after single directional re-writer
24762477
rewriters = [rewrite_transpose, rewrite_flatten,
24772478
rewrite_random_uniform, rewrite_random_uniform_fold_const,
2478-
rewrite_random_normal, rewrite_dropout,
2479+
rewrite_random_normal, rewrite_dropout, rewrite_leakyrelu,
24792480
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
24802481
rewrite_single_direction_gru, rewrite_single_direction_grublock,
24812482
rewrite_bi_direction_gru, rewrite_logical_compare_with_equal,

0 commit comments

Comments
 (0)