Skip to content

Commit f121b10

Browse files
committed
code refactor
1 parent 7798a2f commit f121b10

File tree

5 files changed

+17
-19
lines changed

5 files changed

+17
-19
lines changed

tests/test_backend.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -721,9 +721,9 @@ def test_relu(self):
721721
@skip_caffe2_backend("fails on caffe2 with dim issue")
722722
@check_onnxruntime_incompatibility("Mul")
723723
def test_leaky_relu(self):
724-
for alpha in [0.1, -0.1]:
725-
x_val = np.array([0.5, 1.0, -0.5, -1.0], dtype=np.float32).reshape((2, 2))
726-
x = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT)
724+
for alpha in [0.1, -0.1, 1.0, -1.0, 10.0, -10.0]:
725+
x_val = 1000*np.random.random_sample([1000, 100]).astype(np.float32)
726+
x = tf.placeholder(tf.float32, [None]*x_val.ndim, name=_TFINPUT)
727727
x_ = tf.nn.leaky_relu(x, alpha)
728728
_ = tf.identity(x_, name=_TFOUTPUT)
729729
self._run_test_case([_OUTPUT], {_INPUT: x_val})

tf2onnx/graph.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,8 +552,7 @@ def _push_stack(stack, node, in_stack):
552552
stack.append(node)
553553
if node in in_stack:
554554
raise ValueError('Graph has cycles.')
555-
else:
556-
in_stack[node] = True
555+
in_stack[node] = True
557556

558557
def _get_unvisited_child(g, node, not_visited):
559558
for child in g[node]:

tf2onnx/rewriter/leakyrelu_rewriter.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the MIT license.
33

44
"""
5-
tf2onnx.rewrite - rewrite tensorflow subgraph to onnx leakyrelu op
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx leakyrelu op
66
"""
77

88
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
@@ -32,13 +32,13 @@ def rewrite_leakyrelu(g, ops):
3232
mul_node = match.get_op("mul")
3333
mul_input_node = match.get_op('mul_input')
3434

35-
max_input_edge_name = _find_edges_name_btw_nodes(max_input_node, max_node)
36-
mul_input_edge_name = _find_edges_name_btw_nodes(mul_input_node, mul_node)
35+
max_input_edge_name = _find_edge_name_between_nodes(max_input_node, max_node)
36+
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
3737
if max_input_edge_name == mul_input_edge_name:
3838
alpha = match.get_op("alpha").get_tensor_value()
3939
if alpha >= 1:
4040
continue
41-
leakyrelu = g.make_node("LeakyRelu", inputs=max_input_edge_name, attr={"alpha": alpha},
41+
leakyrelu = g.make_node("LeakyRelu", inputs=[max_input_edge_name], attr={"alpha": alpha},
4242
shapes=[g.get_shape(max_node.output[0])], dtypes=[g.get_dtype(max_node.output[0])])
4343
ops.remove(max_node)
4444
ops.remove(mul_node)
@@ -48,10 +48,10 @@ def rewrite_leakyrelu(g, ops):
4848
return ops
4949

5050

51-
def _find_edges_name_btw_nodes(sender, sinker):
52-
res = []
53-
for sinker_end in sinker.input:
54-
for sender_end in sender.output:
55-
if sinker_end == sender_end:
56-
res.append(sinker_end)
57-
return res
51+
def _find_edge_name_between_nodes(src_node, consumer_node):
52+
# find the first edge connection between two nodes.
53+
for consumer_end in consumer_node.input:
54+
for src_end in src_node.output:
55+
if consumer_end == src_end:
56+
return consumer_end
57+
return None

tf2onnx/rewriter/random_uniform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# Licensed under the MIT license.
33

44
"""
5-
tf2onnx.rewrite - rewrite tensorflow subgraph to onnx random_uniform op
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx random_uniform op
66
"""
77
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
88
from tf2onnx import utils

tf2onnx/shape_inference.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,7 @@ def infer_shape_for_node(g, node):
9898
val = list(shape_attr.floats)
9999
if val:
100100
raise ValueError("placeholder shape has floats value, and not scalar value")
101-
else:
102-
new_shape = ()
101+
new_shape = ()
103102

104103
if new_shape is not None:
105104
g.set_shape(node.output[0], new_shape)

0 commit comments

Comments (0)