Skip to content

Commit 936cec6

Browse files
Make rewriters use get_tensor when needed to avoid subtle bugs (#1526)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 48e9cc0 commit 936cec6

File tree

3 files changed

+9
-24
lines changed

3 files changed

+9
-24
lines changed

tf2onnx/rewriter/layer_normalization_rewriter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ def rewrite_layer_normalization(g, ops):
8686
match_results = list(matcher.match_ops(ops))
8787
if match_results:
8888
for match in match_results:
89-
inp_node = match.get_op('input')
90-
rank = g.get_rank(inp_node.output[0])
89+
input_tensor = match.get_tensor('input')
90+
rank = g.get_rank(input_tensor)
9191
node = match.get_op('bias_add')
92-
if inp_node.name != match.get_op('input_r2').name or inp_node.name != match.get_op('input_r3').name:
92+
if input_tensor != match.get_tensor('input_r2') or input_tensor != match.get_tensor('input_r3'):
9393
continue
9494
if match.get_op('mean').name != match.get_op('mean_r2').name:
9595
continue
@@ -105,8 +105,8 @@ def rewrite_layer_normalization(g, ops):
105105
epsilon = match.get_op('epsilon').get_tensor_value(as_list=False).flatten().tolist()
106106
if len(epsilon) != 1:
107107
continue
108-
scale = match.get_op('scale').output[0]
109-
bias = match.get_op('bias').output[0]
108+
scale = match.get_tensor('scale')
109+
bias = match.get_tensor('bias')
110110
shape = g.make_node("Shape", [inp]).output[0]
111111
dim_2_shape = GraphBuilder(g).make_slice(
112112
{"data": shape, "ends": [2], "starts": [1], "axes": [0]})

tf2onnx/rewriter/leakyrelu_rewriter.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,10 @@ def rewrite_leakyrelu(g, ops):
2828
match_results = list(matcher.match_ops(ops))
2929
for match in match_results:
3030
max_node = match.get_op('max')
31-
max_input_node = match.get_op('max_input')
3231
mul_node = match.get_op("mul")
33-
mul_input_node = match.get_op('mul_input')
3432

35-
max_input_edge_name = _find_edge_name_between_nodes(max_input_node, max_node)
36-
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
33+
max_input_edge_name = match.get_tensor('max_input')
34+
mul_input_edge_name = match.get_tensor('mul_input')
3735
if max_input_edge_name == mul_input_edge_name:
3836
alpha = match.get_op("alpha").get_tensor_value()
3937
if alpha >= 1:
@@ -46,12 +44,3 @@ def rewrite_leakyrelu(g, ops):
4644
g.safe_remove_nodes(to_delete)
4745

4846
return ops
49-
50-
51-
def _find_edge_name_between_nodes(src_node, consumer_node):
52-
# find the first edge connection between two nodes.
53-
for consumer_end in consumer_node.input:
54-
for src_end in src_node.output:
55-
if consumer_end == src_end:
56-
return consumer_end
57-
return None

tf2onnx/rewriter/thresholded_relu_rewriter.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
"""
77

88
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9-
from tf2onnx.rewriter.leakyrelu_rewriter import _find_edge_name_between_nodes
109

1110

1211
# pylint: disable=missing-docstring
@@ -30,14 +29,11 @@ def rewrite_thresholded_relu(g, ops):
3029
match_results = list(matcher.match_ops(ops))
3130

3231
for match in match_results:
33-
greater_node = match.get_op('greater')
34-
greater_input_node = match.get_op('greater_input')
3532
mul_node = match.get_op("mul")
36-
mul_input_node = match.get_op('mul_input')
3733
cast_node = match.get_op('cast')
3834

39-
greater_input_edge_name = _find_edge_name_between_nodes(greater_input_node, greater_node)
40-
mul_input_edge_name = _find_edge_name_between_nodes(mul_input_node, mul_node)
35+
greater_input_edge_name = match.get_tensor('greater_input')
36+
mul_input_edge_name = match.get_tensor('mul_input')
4137
if greater_input_edge_name == mul_input_edge_name:
4238
theta = match.get_op('theta').get_tensor_value()
4339
thresholded_relu = g.make_node("ThresholdedRelu", inputs=[mul_input_edge_name], attr={"alpha": theta},

0 commit comments

Comments (0)