Skip to content

Commit f74b167

Browse files
author
wayuanho
committed
decouple rewriter
1 parent 4d0fced commit f74b167

File tree

9 files changed

+313
-239
lines changed

9 files changed

+313
-239
lines changed

test.onnx

404 Bytes
Binary file not shown.

tf2onnx/rewriter/__init__.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,27 +7,38 @@
77
from __future__ import unicode_literals
88

99
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
10-
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
11-
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
12-
from tf2onnx.rewriter.gemm_rewriter import rewrite_gemm
10+
from tf2onnx.rewriter.conv2d_with_pad_rewriter import rewrite_conv2d_with_pad
11+
from tf2onnx.rewriter.dropout_rewriter import rewrite_dropout
1312
from tf2onnx.rewriter.eye_rewriter import rewrite_eye
14-
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
13+
from tf2onnx.rewriter.flatten_rewriter import rewrite_flatten
14+
from tf2onnx.rewriter.gemm_rewriter import rewrite_gemm
15+
from tf2onnx.rewriter.leakyrelu_rewriter import rewrite_leakyrelu
16+
from tf2onnx.rewriter.random_normal_rewriter import rewrite_random_normal
17+
from tf2onnx.rewriter.random_uniform import rewrite_random_uniform, rewrite_random_uniform_fold_const
1518
from tf2onnx.rewriter.rnn import rewrite_single_direction_lstm, rewrite_bi_direction_lstm, \
1619
rewrite_single_direction_gru, rewrite_bi_direction_gru, \
1720
rewrite_custom_rnn_cell, rewrite_generic_loop
21+
from tf2onnx.rewriter.thresholded_relu_rewriter import rewrite_thresholded_relu
22+
from tf2onnx.rewriter.transpose_rewriter import rewrite_transpose
1823

1924
__all__ = [
2025
"rewrite_cond",
26+
"rewrite_conv2d_with_pad",
27+
"rewrite_dropout",
28+
"rewrite_eye",
29+
"rewrite_flatten",
30+
"rewrite_gemm",
31+
"rewrite_leakyrelu",
32+
"rewrite_random_normal",
2133
"rewrite_random_uniform",
2234
"rewrite_random_uniform_fold_const",
23-
"rewrite_leakyrelu",
2435
"rewrite_thresholded_relu",
25-
"rewrite_eye",
36+
"rewrite_transpose",
37+
2638
"rewrite_single_direction_lstm",
2739
"rewrite_bi_direction_lstm",
2840
"rewrite_single_direction_gru",
2941
"rewrite_bi_direction_gru",
3042
"rewrite_custom_rnn_cell",
31-
"rewrite_gemm",
32-
"rewrite_generic_loop"
43+
"rewrite_generic_loop",
3344
]
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx condv2 op with pad
6+
"""
7+
8+
import numpy as np
9+
10+
from tf2onnx import handler, logging
11+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
# pylint: disable=missing-docstring
17+
18+
19+
def rewrite_conv2d_with_pad(g, ops):
20+
pattern = \
21+
OpTypePattern("Conv2D", name="conv", inputs=[
22+
OpTypePattern("Pad", name="pad"),
23+
OpTypePattern("*")
24+
])
25+
matcher = GraphMatcher(pattern)
26+
match_results = list(matcher.match_ops(ops))
27+
for match in match_results:
28+
conv = match.get_op("conv")
29+
pad = match.get_op("pad")
30+
paddings = pad.inputs[1]
31+
32+
if not paddings.is_const():
33+
continue
34+
mode = pad.get_attr("mode")
35+
if mode:
36+
mode = mode.s.decode("utf-8").lower()
37+
if mode not in [None, "constant"] or len(pad.input) >= 3:
38+
continue
39+
# Conv2D already has a pad
40+
if conv.get_attr("padding") == "SAME":
41+
continue
42+
43+
logger.debug("merge pad [%s] into conv [%s]", pad.name, conv.name)
44+
paddings_val = np.array(paddings.get_tensor_value())
45+
# can't pad on batch or channel dimensions
46+
if np.any(paddings_val[0]) or np.any(paddings_val[3]):
47+
continue
48+
49+
paddings_val = paddings_val[1:3]
50+
paddings_val = paddings_val.transpose().flatten()
51+
g.replace_input(conv, conv.input[0], pad.input[0])
52+
# convert Conv2D
53+
conv.type = "Conv"
54+
func, _ = handler.tf_op.find_effective_op("Conv2D")
55+
func(g, conv)
56+
conv.skip_conversion = True
57+
conv.set_attr("auto_pad", "NOTSET")
58+
conv.set_attr("pads", paddings_val)
59+
return ops

tf2onnx/rewriter/dropout_rewriter.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx dropout op
6+
"""
7+
8+
from tf2onnx import utils
9+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
10+
11+
12+
# pylint: disable=missing-docstring
13+
14+
15+
def rewrite_dropout(g, ops):
16+
patterns = [
17+
OpTypePattern('Mul', name='outputs', inputs=[
18+
OpTypePattern('RealDiv', name="input2"),
19+
OpTypePattern('Floor', inputs=[
20+
OpTypePattern('Add', inputs=[
21+
OpTypePattern(None, name="input3"),
22+
OpTypePattern('RandomUniform|RandomUniformLike'),
23+
])
24+
]),
25+
]),
26+
OpTypePattern("Mul", name="outputs", inputs=[
27+
OpTypePattern("Mul", name="input2"),
28+
OpTypePattern("Cast", inputs=[
29+
OpTypePattern("GreaterEqual", inputs=[
30+
OpTypePattern("RandomUniform|RandomUniformLike"),
31+
OpTypePattern(None, name="input3")
32+
])
33+
])
34+
])
35+
]
36+
for pattern in patterns:
37+
matcher = GraphMatcher(pattern)
38+
match_results = list(matcher.match_ops(ops))
39+
for match in match_results:
40+
inputs2 = match.get_op('input2')
41+
outputs = match.get_op('outputs')
42+
op_name = utils.make_name("Dropout")
43+
out_name = utils.port_name(op_name)
44+
new_node = g.make_node(
45+
"Dropout",
46+
[inputs2.input[0]],
47+
outputs=[out_name],
48+
name=op_name,
49+
attr={"ratio": 1.0},
50+
shapes=[g.get_shape(inputs2.input[0])],
51+
dtypes=[g.get_dtype(inputs2.input[0])]
52+
)
53+
g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
54+
g.safe_remove_nodes(match.get_nodes())
55+
56+
# remove dropout if its ratio is 1.0
57+
for node in g.get_nodes():
58+
if node.type == "Dropout" and node.get_attr("ratio").f == 1.0:
59+
g.replace_all_inputs(g.get_nodes(), node.output[0], node.input[0])
60+
g.remove_node(node.name)
61+
62+
return ops

tf2onnx/rewriter/flatten_rewriter.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx flatten op
6+
"""
7+
8+
import numpy as np
9+
10+
from tf2onnx import utils
11+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
12+
13+
14+
# pylint: disable=missing-docstring
15+
16+
17+
def rewrite_flatten(g, ops):
18+
pattern_fixed_shape_input = \
19+
OpTypePattern('Reshape', name='reshape', inputs=[
20+
OpTypePattern("*", name="input"),
21+
OpTypePattern('Pack', name="pack", inputs=[
22+
OpTypePattern('StridedSlice', name="slice", inputs=[
23+
"*", "*", "*", "*",
24+
]),
25+
"*",
26+
]),
27+
])
28+
pattern_non_fixed_shape_input = \
29+
OpTypePattern('Reshape', name='reshape', inputs=[
30+
OpTypePattern("*", name="input"),
31+
OpTypePattern('Pack', name="pack", inputs=[
32+
OpTypePattern('StridedSlice', name="slice", inputs=[
33+
OpTypePattern('Shape', inputs=[
34+
OpTypePattern("*", name="input2")
35+
]),
36+
"*", "*", "*",
37+
]),
38+
"*",
39+
]),
40+
])
41+
matcher = GraphMatcher(pattern_fixed_shape_input)
42+
match_results_1 = list(matcher.match_ops(ops))
43+
44+
matcher = GraphMatcher(pattern_non_fixed_shape_input)
45+
match_results_2 = list(matcher.match_ops(ops))
46+
47+
match_results = [(match_results_1, True), (match_results_2, False)]
48+
for match_results, check_fixed_input_shape in match_results:
49+
for match in match_results:
50+
input_node = match.get_op('input')
51+
reshape_node = match.get_op('reshape')
52+
pack_node = match.get_op('pack')
53+
slice_node = match.get_op('slice')
54+
need_rewrite = pack_node.inputs[1].is_const() and pack_node.inputs[1].get_tensor_value() == -1
55+
if not need_rewrite:
56+
continue
57+
58+
input_shape = g.get_shape(reshape_node.input[0])
59+
need_rewrite = input_shape is not None
60+
if not need_rewrite:
61+
continue
62+
63+
if check_fixed_input_shape:
64+
need_rewrite = slice_node.inputs[0].is_const() and \
65+
np.array_equal(list(input_shape), list(slice_node.inputs[0].get_tensor_value()))
66+
if not need_rewrite:
67+
continue
68+
69+
begin = slice_node.inputs[1].get_tensor_value(as_list=False)
70+
end = slice_node.inputs[2].get_tensor_value(as_list=False)
71+
strides = slice_node.inputs[3].get_tensor_value(as_list=False)
72+
need_rewrite = np.array_equal(begin, [0]) and len(end) == 1 and \
73+
np.array_equal(strides, [1]) and end[0] - begin[0] == 1
74+
if not need_rewrite:
75+
continue
76+
77+
op_name = utils.make_name("Flatten")
78+
out_name = utils.port_name(op_name)
79+
g.make_node("Flatten", [reshape_node.input[0]], outputs=[out_name], name=op_name)
80+
81+
last_dim = input_shape[-1]
82+
sec_last_dim = input_shape[-2]
83+
new_dim = None
84+
if last_dim > 0 and sec_last_dim > 0:
85+
new_dim = last_dim * sec_last_dim
86+
else:
87+
new_dim = -1
88+
89+
g.set_shape(out_name, input_shape[:-2] + [new_dim])
90+
g.replace_all_inputs(ops, reshape_node.output[0], out_name)
91+
to_delete = [n for n in match.get_nodes() if n != input_node]
92+
g.safe_remove_nodes(to_delete)
93+
94+
return ops
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow subgraph to onnx random normal op
6+
"""
7+
8+
from tf2onnx import utils
9+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
10+
11+
12+
# pylint: disable=missing-docstring
13+
14+
15+
def rewrite_random_normal(g, ops):
16+
pattern = \
17+
OpTypePattern('Add', name='output', inputs=[
18+
OpTypePattern('Mul', name='input2', inputs=[
19+
OpTypePattern('RandomStandardNormal', name='input1', inputs=["*"]), "*"
20+
]), "*"
21+
])
22+
23+
matcher = GraphMatcher(pattern)
24+
match_results = list(matcher.match_ops(ops))
25+
for match in match_results:
26+
output = match.get_op('output')
27+
mean = output.inputs[1].get_tensor_value()
28+
dtype = g.get_dtype(output.output[0])
29+
op_name = utils.make_name("RandomNormal")
30+
out_name = utils.port_name(op_name)
31+
32+
rn_op = match.get_op('input1')
33+
if rn_op.inputs[0].type == "Shape":
34+
shape_node = rn_op.inputs[0]
35+
new_node = g.make_node("RandomNormalLike", [shape_node.input[0]], outputs=[out_name], name=op_name,
36+
attr={"mean": mean, "scale": 1.0, "dtype": dtype})
37+
else:
38+
shape = g.get_shape(output.output[0])
39+
new_node = g.make_node("RandomNormal", [], outputs=[out_name], name=op_name,
40+
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype})
41+
42+
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
43+
g.safe_remove_nodes(match.get_nodes())
44+
return ops

tf2onnx/rewriter/random_uniform.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ def rewrite_random_uniform(g, ops):
3333
to_delete = list(set(match.get_nodes()))
3434
new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
3535
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
36-
for n in to_delete:
37-
g.remove_node(n.name)
36+
g.safe_remove_nodes(to_delete)
3837

3938
return ops
4039

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
tf2onnx.rewriter - rewrite tensorflow transpose op
6+
"""
7+
8+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
9+
10+
11+
# pylint: disable=missing-docstring
12+
13+
14+
def rewrite_transpose(g, ops):
15+
pattern = \
16+
OpTypePattern('Transpose', name='output', inputs=[
17+
OpTypePattern(None),
18+
OpTypePattern('Sub', inputs=[
19+
OpTypePattern('Sub', inputs=["*", "*"]),
20+
OpTypePattern('Range', inputs=["*", "*", "*"]),
21+
]),
22+
])
23+
24+
matcher = GraphMatcher(pattern)
25+
match_results = list(matcher.match_ops(ops))
26+
for match in match_results:
27+
output = match.get_op('output')
28+
shape = g.get_shape(output.input[0])
29+
dims = [i for i in range(len(shape) - 1, -1, -1)]
30+
output.set_attr("perm", dims)
31+
g.remove_input(output, output.input[1])
32+
to_delete = [n for n in match.get_nodes() if n != output]
33+
g.safe_remove_nodes(to_delete)
34+
return ops

0 commit comments

Comments
 (0)