Skip to content

Commit f14d5a7

Browse files
Optimize SpaceToBatchND sequence to DepthwiseConv2dNative dilations (#1461)
* Add support for Explicit padding Signed-off-by: Tom Wildenhain <[email protected]> * Optimize SpaceToBatchND sequence to DepthwiseConv2dNative dilations Signed-off-by: Tom Wildenhain <[email protected]> * pylint Signed-off-by: Tom Wildenhain <[email protected]>
1 parent c5a78fc commit f14d5a7

File tree

4 files changed

+80
-1
lines changed

4 files changed

+80
-1
lines changed

tests/test_backend.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3411,6 +3411,18 @@ def func(x, z):
34113411
return batch_to_space_nd(x, y, z, name=_TFOUTPUT)
34123412
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT2: z_val})
34133413

3414+
def test_depthwise_dilations_pattern(self):
3415+
x_val = np.random.random_sample([1, 33, 34, 960]).astype(np.float32)
3416+
kernel = np.random.random_sample([3, 3, 960, 1]).astype(np.float32)
3417+
block_size = np.array([3, 3], np.int64)
3418+
pad = np.array([[2, 4], [5, 3]])
3419+
crop = np.array([[0, 0], [0, 0]])
3420+
def func(x):
3421+
y = space_to_batch_nd(x, block_size, pad)
3422+
z = tf.nn.depthwise_conv2d(y, kernel, strides=[1, 1, 1, 1], padding='VALID')
3423+
return batch_to_space_nd(z, block_size, crop, name=_TFOUTPUT)
3424+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3425+
34143426
@check_opset_min_version(11, "SpaceToBatchND")
34153427
def test_space_to_batchnd_non_const_7d(self):
34163428
x_type, y_type, z_type = np.float32, np.int64, np.int64

tf2onnx/rewriter/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from tf2onnx.rewriter.cond_rewriter import rewrite_cond
1010
from tf2onnx.rewriter.conv2d_with_pad_rewriter import rewrite_conv2d_with_pad
11+
from tf2onnx.rewriter.depthwise_conv_dilations_rewriter import rewrite_depthwise_conv_dilations
1112
from tf2onnx.rewriter.dropout_rewriter import rewrite_dropout
1213
from tf2onnx.rewriter.eye_rewriter import rewrite_eye
1314
from tf2onnx.rewriter.flatten_rewriter import rewrite_flatten
@@ -46,5 +47,6 @@
4647
"rewrite_generic_loop",
4748
"rewrite_biasadd_with_conv2d",
4849
"rewrite_quantize_and_dequantize",
49-
"rewrite_layer_normalization"
50+
"rewrite_layer_normalization",
51+
"rewrite_depthwise_conv_dilations"
5052
]
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
4+
"""
5+
tf2onnx.rewriter.depthwise_conv_dilations_rewriter - Rewrites the patten used to represent dilations
6+
pat = SpaceToBatchND->DepthwiseConv2dNative->BatchToSpaceND
7+
"""
8+
9+
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
10+
11+
# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable
12+
13+
14+
def rewrite_depthwise_conv_dilations(g, ops):
15+
pattern1 = \
16+
OpTypePattern("BatchToSpaceND", name="batch_to_space", inputs=[
17+
OpTypePattern("DepthwiseConv2dNative", name="depthwise_conv", inputs=[
18+
OpTypePattern("SpaceToBatchND", name="space_to_batch", inputs=[
19+
OpTypePattern("*"),
20+
OpTypePattern("Const|ConstV2"),
21+
OpTypePattern("Const|ConstV2"),
22+
]),
23+
OpTypePattern("*"),
24+
]),
25+
OpTypePattern("Const|ConstV2"),
26+
OpTypePattern("Const|ConstV2"),
27+
])
28+
29+
for pattern in [pattern1]:
30+
matcher = GraphMatcher(pattern, allow_reorder=False)
31+
match_results = list(matcher.match_ops(ops))
32+
for match_result in match_results:
33+
space_to_batch = match_result.get_op("space_to_batch")
34+
depthwise_conv = match_result.get_op("depthwise_conv")
35+
batch_to_space = match_result.get_op("batch_to_space")
36+
37+
block_shape1 = space_to_batch.inputs[1].get_tensor_value(as_list=True)
38+
paddings = space_to_batch.inputs[2].get_tensor_value(as_list=False).flatten().tolist()
39+
block_shape2 = batch_to_space.inputs[1].get_tensor_value(as_list=True)
40+
crops = batch_to_space.inputs[2].get_tensor_value(as_list=True)
41+
if block_shape1 != block_shape2:
42+
continue
43+
if depthwise_conv.get_attr_value("dilations", [1, 1, 1, 1]) != [1, 1, 1, 1]:
44+
continue
45+
if depthwise_conv.get_attr_value("strides", [1, 1, 1, 1]) != [1, 1, 1, 1]:
46+
continue
47+
if depthwise_conv.get_attr_value("data_format", b"NHWC") != b"NHWC":
48+
continue
49+
if depthwise_conv.get_attr_value("padding") != b"VALID":
50+
continue
51+
if crops != [[0, 0], [0, 0]]:
52+
continue
53+
54+
inp = space_to_batch.input[0]
55+
kernel = depthwise_conv.input[1]
56+
57+
g.replace_inputs(depthwise_conv, [inp, kernel])
58+
depthwise_conv.set_attr("dilations", [1] + block_shape1 + [1])
59+
depthwise_conv.set_attr("explicit_paddings", [0, 0] + paddings + [0, 0])
60+
depthwise_conv.set_attr("padding", "EXPLICIT")
61+
g.copy_shape(batch_to_space.output[0], depthwise_conv.output[0])
62+
g.replace_all_inputs(batch_to_space.output[0], depthwise_conv.output[0])
63+
64+
return g.get_nodes()

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,7 @@ def compat_handler(ctx, node, **kwargs):
604604
rewrite_random_uniform_fold_const,
605605
rewrite_random_normal,
606606
rewrite_dropout,
607+
rewrite_depthwise_conv_dilations,
607608
rewrite_eye,
608609
rewrite_leakyrelu,
609610
rewrite_thresholded_relu,

0 commit comments

Comments
 (0)