Skip to content

Commit 50259c9

Browse files
committed
Merge branch 'master' into gs/fix-bn
2 parents aecd66e + 5810313 commit 50259c9

File tree

8 files changed

+95
-11
lines changed

8 files changed

+95
-11
lines changed

tests/test_backend.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -721,11 +721,13 @@ def test_relu(self):
721721
@skip_caffe2_backend("fails on caffe2 with dim issue")
722722
@check_onnxruntime_incompatibility("Mul")
723723
def test_leaky_relu(self):
724-
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
725-
x = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
726-
x_ = tf.nn.leaky_relu(x)
727-
_ = tf.identity(x_, name=_TFOUTPUT)
728-
self._run_test_case([_OUTPUT], {_INPUT: x_val})
724+
for alpha in [0.1, -0.1, 1.0, -1.0, 10.0, -10.0]:
725+
x_val = 1000*np.random.random_sample([1000, 100]).astype(np.float32)
726+
x = tf.placeholder(tf.float32, [None]*x_val.ndim, name=_TFINPUT)
727+
x_ = tf.nn.leaky_relu(x, alpha)
728+
_ = tf.identity(x_, name=_TFOUTPUT)
729+
self._run_test_case([_OUTPUT], {_INPUT: x_val})
730+
tf.reset_default_graph()
729731

730732
@check_onnxruntime_incompatibility("Elu")
731733
def test_elu(self):

tf2onnx/graph.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,8 +552,7 @@ def _push_stack(stack, node, in_stack):
552552
stack.append(node)
553553
if node in in_stack:
554554
raise ValueError('Graph has cycles.')
555-
else:
556-
in_stack[node] = True
555+
in_stack[node] = True
557556

558557
def _get_unvisited_child(g, node, not_visited):
559558
for child in g[node]:

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
3+
34
"""Transpose Optimizer."""
45

56
from __future__ import unicode_literals
7+
from collections import defaultdict
68

79
import logging
810

@@ -113,6 +115,27 @@ def post_optimize_action(self):
113115
self._g.update_proto()
114116
self._g.topological_sort(self._g.get_nodes())
115117

118+
def merge_duplicated_transposes(self):
119+
# strategy used in previous procedure is to move transpose nodes down if possible,
120+
# and it means that when a node has n outputs then n transpose will be generated,
121+
# so we should merge them back to one if they can't be eliminated in previous procedure.
122+
graph = self._g
123+
input_transposes_map = defaultdict(list)
124+
for node in graph.get_nodes():
125+
if node.type == "Transpose":
126+
key = (node.input[0], str(node.get_attr("perm").ints))
127+
input_transposes_map[key].append(node)
128+
129+
for transposes in input_transposes_map.values():
130+
# merge transpose nodes into one: make nodes use the output of the first transpose node
131+
transpose_out = transposes[0].output[0]
132+
for node in transposes[1:]:
133+
old_transpose_out = node.output[0]
134+
graph.replace_all_inputs(graph.get_nodes(), old_transpose_out, transpose_out)
135+
136+
# dangling transpose nodes can be deleted
137+
graph.delete_unused_nodes(graph.outputs)
138+
116139
def optimize(self):
117140
previous_counter = self._g.dump_node_statistics()
118141
no_action = False
@@ -140,6 +163,8 @@ def optimize(self):
140163
break
141164

142165
log.debug("finish after " + str(iteration_cnt) + " iteration(s)")
166+
167+
self.merge_duplicated_transposes()
143168
self.post_optimize_action()
144169

145170
current_counter = self._g.dump_node_statistics()

tf2onnx/rewriter/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"cond_rewriter",
1111
"custom_rnn_rewriter",
1212
"gru_rewriter",
13+
"leakyrelu_rewriter",
1314
"loop_rewriter",
1415
"loop_rewriter_base",
1516
"lstm_rewriter",
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx leakyrelu op
6+
"""
7+
8+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9+
10+
11+
# pylint: disable=missing-docstring
12+
13+
14+
def rewrite_leakyrelu(g, ops):
15+
if g.opset < 6:
16+
return ops
17+
18+
pattern = \
19+
OpTypePattern('Maximum', name='max', inputs=[
20+
OpTypePattern('Mul', name='mul', inputs=[
21+
OpTypePattern('Const', name='alpha'),
22+
OpTypePattern('*', name='mul_input'),
23+
]),
24+
OpTypePattern('*', name='max_input'),
25+
])
26+
27+
matcher = GraphMatcher(pattern, allow_reorder=True)
28+
match_results = list(matcher.match_ops(ops))
29+
for match in match_results:
30+
max_node = match.get_op('max')
31+
max_input_node = match.get_op('max_input')
32+
mul_node = match.get_op("mul")
33+
mul_input_node = match.get_op('mul_input')
34+
35+
max_input_edge_name = _find_edge_name_between_nodes(max_input_node, max_node)
36+
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
37+
if max_input_edge_name == mul_input_edge_name:
38+
alpha = match.get_op("alpha").get_tensor_value()
39+
if alpha >= 1:
40+
continue
41+
leakyrelu = g.make_node("LeakyRelu", inputs=[max_input_edge_name], attr={"alpha": alpha},
42+
shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])])
43+
ops.remove(max_node)
44+
ops.remove(mul_node)
45+
ops.append(leakyrelu)
46+
g.replace_all_inputs(ops, max_node.output[0], leakyrelu.output[0])
47+
48+
return ops
49+
50+
51+
def _find_edge_name_between_nodes(src_node, consumer_node):
52+
# find the first edge connection between two nodes.
53+
for consumer_end in consumer_node.input:
54+
for src_end in src_node.output:
55+
if consumer_end == src_end:
56+
return consumer_end
57+
return None

tf2onnx/rewriter/random_uniform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the MIT license.
33

44
"""
5-
tf2onnx.rewrite - rewrite tensorflow subgraph to onnx random_uniform op
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx random_uniform op
66
"""
77
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
88
from tf2onnx import utils

tf2onnx/shape_inference.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ def infer_shape_for_node(g, node):
9898
val = list(shape_attr.floats)
9999
if val:
100100
raise ValueError("placeholder shape has floats value, and not scalar value")
101-
else:
102-
new_shape = ()
101+
new_shape = ()
103102

104103
if new_shape is not None:
105104
g.set_shape(node.output[0], new_shape)

tf2onnx/tfonnx.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2727
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
2828
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
29+
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
2930
from tf2onnx.rewriter.rnn import rewrite_bi_direction_gru
3031
from tf2onnx.rewriter.rnn import rewrite_custom_rnn_cell
3132
from tf2onnx.rewriter.rnn import rewrite_generic_loop
@@ -2482,7 +2483,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
24822483
# bi-directional re-writer should be placed after single directional re-writer
24832484
rewriters = [rewrite_transpose, rewrite_flatten,
24842485
rewrite_random_uniform, rewrite_random_uniform_fold_const,
2485-
rewrite_random_normal, rewrite_dropout,
2486+
rewrite_random_normal, rewrite_dropout, rewrite_leakyrelu,
24862487
rewrite_single_direction_lstm, rewrite_bi_direction_lstm,
24872488
rewrite_single_direction_gru, rewrite_single_direction_grublock,
24882489
rewrite_bi_direction_gru, rewrite_logical_compare_with_equal,

0 commit comments

Comments
 (0)