Skip to content

Commit e5b1b2a

Browse files
Match SpaceToBatch->Conv->BatchToSpace pattern more generally (#1468)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 75568a7 commit e5b1b2a

File tree

5 files changed

+176
-67
lines changed

5 files changed

+176
-67
lines changed

tests/test_backend.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,64 @@ def func(x):
730730
process_args={"inputs_as_nchw": [_INPUT]},
731731
onnx_feed_dict={_INPUT: x_val_for_onnx})
732732

733+
@skip_tflite("TFlite adds ops that obscure pattern")
734+
@check_tf_min_version("2.0")
735+
def test_conv1d_dilations_rewriter(self):
736+
x_shape = [2, 32, 3]
737+
x_val = make_xval(x_shape)
738+
for p in ['SAME', 'VALID']:
739+
def func(x):
740+
t = tf.keras.layers.Conv1D(filters=768, kernel_size=3, dilation_rate=3, padding=p)
741+
t.build(x_shape)
742+
y = t.call(x)
743+
return tf.identity(y, name=_TFOUTPUT)
744+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
745+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
746+
747+
@check_tf_min_version("2.0")
748+
def test_conv2d_dilations_rewriter(self):
749+
x_shape = [2, 32, 16, 3]
750+
x_val = make_xval(x_shape)
751+
for p in ['SAME', 'VALID']:
752+
def func(x):
753+
t = tf.keras.layers.Conv2D(filters=768, kernel_size=3, dilation_rate=3, padding=p)
754+
t.build(x_shape)
755+
y = t.call(x)
756+
return tf.identity(y, name=_TFOUTPUT)
757+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
758+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
759+
def func(x):
760+
t = tf.keras.layers.DepthwiseConv2D(kernel_size=3, dilation_rate=3, padding=p)
761+
t.build(x_shape)
762+
y = t.call(x)
763+
return tf.identity(y, name=_TFOUTPUT)
764+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
765+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
766+
767+
@check_tf_min_version("2.0")
768+
def test_conv3d_dilations_rewriter(self):
769+
x_shape = [2, 32, 16, 8, 3]
770+
x_val = make_xval(x_shape)
771+
for p in ['SAME', 'VALID']:
772+
def func(x):
773+
t = tf.keras.layers.Conv3D(filters=768, kernel_size=3, dilation_rate=3, padding=p)
774+
t.build(x_shape)
775+
y = t.call(x)
776+
return tf.identity(y, name=_TFOUTPUT)
777+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
778+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
779+
780+
@skip_tf2("Uses tf.layers")
781+
def test_conv1d_tf1_dilations_rewriter(self):
782+
x_shape = [2, 32, 3]
783+
x_val = make_xval(x_shape)
784+
for p in ['SAME', 'VALID']:
785+
def func(x):
786+
y = tf.layers.conv1d(x, filters=768, kernel_size=3, dilation_rate=3, padding=p, name="conv1")
787+
return tf.identity(y, name=_TFOUTPUT)
788+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
789+
graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
790+
733791
def test_lrn_default(self):
734792
x_shape = [1, 3, 4, 3]
735793
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)

tf2onnx/rewriter/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
1010
from tf2onnx.rewriter.conv2d_with_pad_rewriter import rewrite_conv2d_with_pad
11-
from tf2onnx.rewriter.depthwise_conv_dilations_rewriter import rewrite_depthwise_conv_dilations
11+
from tf2onnx.rewriter.conv_dilations_rewriter import rewrite_conv_dilations
1212
from tf2onnx.rewriter.dropout_rewriter import rewrite_dropout
1313
from tf2onnx.rewriter.eye_rewriter import rewrite_eye
1414
from tf2onnx.rewriter.flatten_rewriter import rewrite_flatten
@@ -48,5 +48,5 @@
4848
"rewrite_biasadd_with_conv2d",
4949
"rewrite_quantize_and_dequantize",
5050
"rewrite_layer_normalization",
51-
"rewrite_depthwise_conv_dilations"
51+
"rewrite_conv_dilations"
5252
]
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""
5+
tf2onnx.rewriter.conv_dilations_rewriter - Rewrites the patten used to represent dilations
6+
pat = SpaceToBatchND->DepthwiseConv2dNative->BatchToSpaceND
7+
"""
8+
9+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
10+
11+
# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable
12+
13+
14+
def rewrite_conv_dilations(g, ops):
15+
pattern1 = \
16+
OpTypePattern("BatchToSpaceND", name="batch_to_space", inputs=[
17+
OpTypePattern("DepthwiseConv2dNative|Conv2D|Conv3D", name="conv", inputs=[
18+
OpTypePattern("SpaceToBatchND", name="space_to_batch", inputs=[
19+
OpTypePattern("*"),
20+
OpTypePattern("Const|ConstV2"),
21+
OpTypePattern("Const|ConstV2"),
22+
]),
23+
OpTypePattern("*"),
24+
]),
25+
OpTypePattern("Const|ConstV2"),
26+
OpTypePattern("Const|ConstV2"),
27+
])
28+
pattern2 = \
29+
OpTypePattern("BatchToSpaceND", name="batch_to_space", inputs=[
30+
OpTypePattern("Squeeze", name="squeeze", inputs=[
31+
OpTypePattern("DepthwiseConv2dNative|Conv2D|Conv3D", name="conv", inputs=[
32+
OpTypePattern("ExpandDims", name="expand", inputs=[
33+
OpTypePattern("SpaceToBatchND", name="space_to_batch", inputs=[
34+
OpTypePattern("*"),
35+
OpTypePattern("Const|ConstV2"),
36+
OpTypePattern("Const|ConstV2"),
37+
]),
38+
OpTypePattern("Const|ConstV2"),
39+
]),
40+
OpTypePattern("*"),
41+
]),
42+
]),
43+
OpTypePattern("Const|ConstV2"),
44+
OpTypePattern("Const|ConstV2"),
45+
])
46+
47+
for pattern in [pattern1, pattern2]:
48+
matcher = GraphMatcher(pattern, allow_reorder=False)
49+
match_results = list(matcher.match_ops(ops))
50+
for match_result in match_results:
51+
is_conv_1d = pattern is pattern2
52+
space_to_batch = match_result.get_op("space_to_batch")
53+
conv = match_result.get_op("conv")
54+
batch_to_space = match_result.get_op("batch_to_space")
55+
if is_conv_1d:
56+
expand = match_result.get_op("expand")
57+
expand_axis = expand.inputs[1].get_tensor_value(as_list=True)
58+
squeeze = match_result.get_op("squeeze")
59+
squeeze_axes = squeeze.get_attr_value("squeeze_dims")
60+
if expand_axis not in [1, -3] or squeeze_axes not in [[1], [-3]]:
61+
continue
62+
63+
block_shape1 = space_to_batch.inputs[1].get_tensor_value(as_list=True)
64+
paddings = space_to_batch.inputs[2].get_tensor_value(as_list=True)
65+
block_shape2 = batch_to_space.inputs[1].get_tensor_value(as_list=True)
66+
crops = batch_to_space.inputs[2].get_tensor_value(as_list=True)
67+
68+
if block_shape1 != block_shape2:
69+
continue
70+
ndims = 2 if is_conv_1d else len(block_shape1)
71+
data_format = b"NHWC" if ndims == 2 else b"NDHWC"
72+
ones = [1] * (ndims + 2)
73+
if conv.get_attr_value("dilations", ones) != ones:
74+
continue
75+
if conv.get_attr_value("strides", ones) != ones:
76+
continue
77+
if conv.get_attr_value("data_format", data_format) != data_format:
78+
continue
79+
if conv.get_attr_value("padding") != b"VALID":
80+
continue
81+
82+
83+
base_start_pad = [p[0] for p in paddings]
84+
if any(c[0] != 0 for c in crops):
85+
continue
86+
base_end_pad = [p[1] - c[1] for p, c in zip(paddings, crops)]
87+
if not all(0 <= p[1] - bp < bs for p, bp, bs in zip(paddings, base_end_pad, block_shape1)):
88+
continue
89+
90+
if is_conv_1d:
91+
inp = space_to_batch.input[0]
92+
g.replace_inputs(expand, [inp, expand.input[1]])
93+
g.copy_shape(batch_to_space.output[0], squeeze.output[0])
94+
g.replace_all_inputs(batch_to_space.output[0], squeeze.output[0])
95+
squeeze_out_shape = g.get_shape(squeeze.output[0])
96+
g.set_shape(squeeze.input[0], squeeze_out_shape[:1] + [1] + squeeze_out_shape[1:])
97+
expand_inp_shape = g.get_shape(expand.input[0])
98+
g.set_shape(expand.output[0], expand_inp_shape[:1] + [1] + expand_inp_shape[1:])
99+
100+
base_start_pad = [0] + base_start_pad
101+
base_end_pad = [0] + base_end_pad
102+
block_shape1 = [1] + block_shape1
103+
else:
104+
inp = space_to_batch.input[0]
105+
kernel = conv.input[1]
106+
g.replace_inputs(conv, [inp, kernel])
107+
g.copy_shape(batch_to_space.output[0], conv.output[0])
108+
g.replace_all_inputs(batch_to_space.output[0], conv.output[0])
109+
110+
base_pad_flat = [0, 0] + [x for s, e in zip(base_start_pad, base_end_pad) for x in [s, e]] + [0, 0]
111+
conv.set_attr("dilations", [1] + block_shape1 + [1])
112+
conv.set_attr("explicit_paddings", base_pad_flat)
113+
conv.set_attr("padding", "EXPLICIT")
114+
115+
return g.get_nodes()

tf2onnx/rewriter/depthwise_conv_dilations_rewriter.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

tf2onnx/tfonnx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ def compat_handler(ctx, node, **kwargs):
605605
rewrite_random_uniform_fold_const,
606606
rewrite_random_normal,
607607
rewrite_dropout,
608-
rewrite_depthwise_conv_dilations,
608+
rewrite_conv_dilations,
609609
rewrite_eye,
610610
rewrite_leakyrelu,
611611
rewrite_thresholded_relu,

0 commit comments

Comments
 (0)