
Commit 596f237

Make conv dilations rewriter work for dynamic pads (#1470)
Signed-off-by: Tom Wildenhain <[email protected]>
Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 96a040b commit 596f237

File tree: 2 files changed, +79 −21 lines changed


tests/test_backend.py

Lines changed: 28 additions & 0 deletions
@@ -764,6 +764,22 @@ def func(x):
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
                             graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))

+    @check_tf_min_version("2.0")
+    @skip_tflite("TFlite adds ops that obscure pattern")
+    @allow_missing_shapes("Rewriting makes some shapes known")
+    def test_conv2d_dilations_rewriter_unknown_shape(self):
+        x_shape = [2, 32, 16, 3]
+        x_val = make_xval(x_shape)
+        def func():
+            x = tf_placeholder(tf.float32, [2, None, None, 3], name=_TFINPUT)
+            t = tf.keras.layers.Conv2D(filters=768, kernel_size=3, dilation_rate=3, padding="VALID")
+            t.build(x_shape)
+            y = t.call(x)
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2,
+                            as_session=True, premade_placeholders=True,
+                            graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
+
     @check_tf_min_version("2.0")
     def test_conv3d_dilations_rewriter(self):
         x_shape = [2, 32, 16, 8, 3]
@@ -788,6 +804,18 @@ def func(x):
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2, as_session=True,
                             graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))

+    @skip_tf2("Uses tf.layers")
+    def test_conv1d_tf1_dilations_rewriter_unknown_shape(self):
+        x_shape = [2, 32, 3]
+        x_val = make_xval(x_shape)
+        def func():
+            x = tf_placeholder(tf.float32, [2, None, 3], name=_TFINPUT)
+            y = tf.layers.conv1d(x, filters=768, kernel_size=3, dilation_rate=3, padding="VALID", name="conv1")
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04, atol=1e-2,
+                            as_session=True, premade_placeholders=True,
+                            graph_validator=lambda g: check_op_count(g, "Reshape", 0, disabled=False))
+
     def test_lrn_default(self):
         x_shape = [1, 3, 4, 3]
         x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
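
Background for these two tests (a note, not part of the diff): when every dimension is static, TF emits a dilated convolution as a single Conv2D with a dilations attribute, and any space-to-batch paddings fold to constants. With unknown spatial dims, TF falls back to the SpaceToBatchND/BatchToSpaceND lowering and computes the paddings at runtime, so they are no longer Const nodes and the old rewriter pattern could not match. A minimal sketch to observe this, assuming TF 2.x (f and the op-type filter are illustrative):

import tensorflow as tf

conv = tf.keras.layers.Conv2D(filters=8, kernel_size=3, dilation_rate=3, padding="VALID")

@tf.function
def f(x):
    return conv(x)

# Trace with unknown spatial dims, as in the tests above.
cf = f.get_concrete_function(tf.TensorSpec([2, None, None, 3], tf.float32))
op_types = {op.type for op in cf.graph.get_operations()}
# Expect the space-to-batch lowering plus runtime-computed paddings;
# the exact op set may vary by TF version.
print(op_types & {"SpaceToBatchND", "BatchToSpaceND", "FloorMod"})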

tf2onnx/rewriter/conv_dilations_rewriter.py

Lines changed: 51 additions & 21 deletions
@@ -6,6 +6,7 @@
 pat = SpaceToBatchND->DepthwiseConv2dNative->BatchToSpaceND
 """

+import numpy as np
 from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher

 # pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable
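
For reference, the identity behind the rewrite: the space-to-batch form computes exactly a dilated convolution, so the subgraph can be collapsed into a single conv with a dilations attribute. A quick numeric check of that identity (a sketch, not part of the commit; tf.nn.atrous_conv2d is the public API that builds the SpaceToBatchND -> Conv2D -> BatchToSpaceND graph):

import numpy as np
import tensorflow as tf

x = tf.random.normal([1, 32, 32, 3])
k = tf.random.normal([3, 3, 3, 8])
# Space-to-batch formulation of a dilated conv...
a = tf.nn.atrous_conv2d(x, k, rate=3, padding="VALID")
# ...matches Conv2D with the dilations attribute set directly.
b = tf.nn.conv2d(x, k, strides=1, padding="VALID", dilations=3)
np.testing.assert_allclose(a.numpy(), b.numpy(), rtol=1e-4, atol=1e-4)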
@@ -18,30 +19,30 @@ def rewrite_conv_dilations(g, ops):
                 OpTypePattern("SpaceToBatchND", name="space_to_batch", inputs=[
                     OpTypePattern("*"),
                     OpTypePattern("Const|ConstV2"),
-                    OpTypePattern("Const|ConstV2"),
+                    OpTypePattern("*"),
                 ]),
                 OpTypePattern("*"),
             ]),
             OpTypePattern("Const|ConstV2"),
-            OpTypePattern("Const|ConstV2"),
+            OpTypePattern("*"),
         ])
     pattern2 = \
         OpTypePattern("BatchToSpaceND", name="batch_to_space", inputs=[
             OpTypePattern("Squeeze", name="squeeze", inputs=[
-                OpTypePattern("DepthwiseConv2dNative|Conv2D|Conv3D", name="conv", inputs=[
+                OpTypePattern("Conv2D", name="conv", inputs=[
                     OpTypePattern("ExpandDims", name="expand", inputs=[
                         OpTypePattern("SpaceToBatchND", name="space_to_batch", inputs=[
                             OpTypePattern("*"),
                             OpTypePattern("Const|ConstV2"),
-                            OpTypePattern("Const|ConstV2"),
+                            OpTypePattern("*"),
                         ]),
                         OpTypePattern("Const|ConstV2"),
                     ]),
                     OpTypePattern("*"),
                 ]),
             ]),
             OpTypePattern("Const|ConstV2"),
-            OpTypePattern("Const|ConstV2"),
+            OpTypePattern("*"),
         ])

     for pattern in [pattern1, pattern2]:
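
The Const|ConstV2 -> "*" relaxations above are what admit dynamically computed paddings and crops; constness is then checked per input inside the match loop rather than by the pattern itself. Roughly how the patterns are consumed (a sketch following the loop below in this file; names match the pattern's name= bindings):

matcher = GraphMatcher(pattern, allow_reorder=False)
for match in matcher.match_ops(ops):
    conv = match.get_op("conv")
    space_to_batch = match.get_op("space_to_batch")
    batch_to_space = match.get_op("batch_to_space")
    # inputs[2] (paddings/crops) may now be any node, so the rewriter
    # tests .is_const() before calling get_tensor_value on it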
@@ -61,9 +62,7 @@ def rewrite_conv_dilations(g, ops):
                 continue

             block_shape1 = space_to_batch.inputs[1].get_tensor_value(as_list=True)
-            paddings = space_to_batch.inputs[2].get_tensor_value(as_list=True)
             block_shape2 = batch_to_space.inputs[1].get_tensor_value(as_list=True)
-            crops = batch_to_space.inputs[2].get_tensor_value(as_list=True)

             if block_shape1 != block_shape2:
                 continue
@@ -79,15 +78,27 @@ def rewrite_conv_dilations(g, ops):
             if conv.get_attr_value("padding") != b"VALID":
                 continue

-
-            base_start_pad = [p[0] for p in paddings]
-            if any(c[0] != 0 for c in crops):
-                continue
-            base_end_pad = [p[1] - c[1] for p, c in zip(paddings, crops)]
-            if not all(0 <= p[1] - bp < bs for p, bp, bs in zip(paddings, base_end_pad, block_shape1)):
-                continue
+            if space_to_batch.inputs[2].is_const() and batch_to_space.inputs[2].is_const():
+                paddings = space_to_batch.inputs[2].get_tensor_value(as_list=True)
+                crops = batch_to_space.inputs[2].get_tensor_value(as_list=True)
+                base_start_pad = [p[0] for p in paddings]
+                if any(c[0] != 0 for c in crops):
+                    continue
+                base_end_pad = [p[1] - c[1] for p, c in zip(paddings, crops)]
+                if not all(0 <= p[1] - bp < bs for p, bp, bs in zip(paddings, base_end_pad, block_shape1)):
+                    continue
+                pad_mode = "EXPLICIT"
+                if is_conv_1d:
+                    base_start_pad = [0] + base_start_pad
+                    base_end_pad = [0] + base_end_pad
+                base_pad_flat = [0, 0] + [x for s, e in zip(base_start_pad, base_end_pad) for x in [s, e]] + [0, 0]
+            else:
+                pad_mode = determine_pad_mode(space_to_batch.inputs[2])
+                if pad_mode is None:
+                    continue

             if is_conv_1d:
+                block_shape1 = [1] + block_shape1
                 inp = space_to_batch.input[0]
                 g.replace_inputs(expand, [inp, expand.input[1]])
                 g.copy_shape(batch_to_space.output[0], squeeze.output[0])
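
The is_const branch above folds the constant SpaceToBatchND paddings, net of the BatchToSpaceND crops, into Conv2D's explicit_paddings, a flat per-dimension (begin, end) list that for NHWC reads [N_begin, N_end, H_begin, H_end, W_begin, W_end, C_begin, C_end]. A worked example of the flattening (values illustrative, not from the commit):

base_start_pad = [1, 1]   # per spatial dim (H, W)
base_end_pad = [2, 2]
# Same expression as base_pad_flat above: batch pair, spatial pairs, channel pair.
base_pad_flat = [0, 0] + [x for s, e in zip(base_start_pad, base_end_pad) for x in [s, e]] + [0, 0]
print(base_pad_flat)  # [0, 0, 1, 2, 1, 2, 0, 0]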
@@ -96,20 +107,39 @@ def rewrite_conv_dilations(g, ops):
                 g.set_shape(squeeze.input[0], squeeze_out_shape[:1] + [1] + squeeze_out_shape[1:])
                 expand_inp_shape = g.get_shape(expand.input[0])
                 g.set_shape(expand.output[0], expand_inp_shape[:1] + [1] + expand_inp_shape[1:])
-
-                base_start_pad = [0] + base_start_pad
-                base_end_pad = [0] + base_end_pad
-                block_shape1 = [1] + block_shape1
             else:
                 inp = space_to_batch.input[0]
                 kernel = conv.input[1]
                 g.replace_inputs(conv, [inp, kernel])
                 g.copy_shape(batch_to_space.output[0], conv.output[0])
                 g.replace_all_inputs(batch_to_space.output[0], conv.output[0])

-            base_pad_flat = [0, 0] + [x for s, e in zip(base_start_pad, base_end_pad) for x in [s, e]] + [0, 0]
             conv.set_attr("dilations", [1] + block_shape1 + [1])
-            conv.set_attr("explicit_paddings", base_pad_flat)
-            conv.set_attr("padding", "EXPLICIT")
+            conv.set_attr("padding", pad_mode)
+            if pad_mode == "EXPLICIT":
+                conv.set_attr("explicit_paddings", base_pad_flat)

     return g.get_nodes()
+
+def determine_pad_mode(paddings_inp_node):
+    tensor_ops = set(["Concat", "ConcatV2", "ConcatV3", "StridedSlice", "Pack", "ExpandDims", "Identity"])
+    while paddings_inp_node.type in tensor_ops:
+        non_const = [inp for inp in paddings_inp_node.inputs if not inp.is_const()]
+        if len(non_const) == 0:
+            return None
+        paddings_inp_node = non_const[0]
+    if paddings_inp_node.type == "FloorMod":
+        return "VALID"
+    if paddings_inp_node.type in ["Add", "AddV2"]:
+        if paddings_inp_node.inputs[0].type == "FloorMod":
+            pad = paddings_inp_node.inputs[1]
+        elif paddings_inp_node.inputs[1].type == "FloorMod":
+            pad = paddings_inp_node.inputs[0]
+        else:
+            return None
+        if pad.is_const():
+            if np.any(pad.get_tensor_value(as_list=False)):
+                #return "SAME" ORT doesn't implement dilations for SAME autopadding yet
+                return None
+            return "VALID"
+    return None
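
Why a FloorMod producer pins down the padding mode (a sketch of TF's dynamic padding arithmetic; the concrete block, spatial, and base_pad values are illustrative): for padding="VALID" the lowering only pads each spatial dim up to the next multiple of the block shape, which is a bare FloorMod computation, while for "SAME" it first adds a constant base pad derived from the kernel size, yielding Add(Const, FloorMod). determine_pad_mode walks back through shape-plumbing ops (Concat, Pack, StridedSlice, ...) to that producer and rejects the nonzero-constant case because ONNX Runtime does not yet implement dilations with SAME auto-padding.

import tensorflow as tf

block = tf.constant([3, 3])      # dilation rate per spatial dim
spatial = tf.constant([32, 16])  # stand-in for tf.shape(x)[1:3]

# VALID: pad up to the next multiple of the block -> a bare FloorMod node.
pad_valid = (block - spatial % block) % block
# SAME: constant base pad plus the same FloorMod term -> Add(Const, FloorMod),
# which determine_pad_mode rejects when the constant is nonzero.
base_pad = tf.constant([2, 2])   # e.g. (kernel_size - 1) * dilation, split per side
pad_same = base_pad + pad_valid
print(pad_valid.numpy(), pad_same.numpy())  # [1 2] [3 4]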
