Skip to content

Commit d8ad301

Browse files
authored
[Rewriter] Add optimizer to fold Pad operators into Conv (#2363)
Following (#2301), `fuse_pad_into_conv` rule set is introduced to reduce the following list of operators: - Conv ∘ Pad -> Conv - ConvInteger ∘ Pad -> ConvInteger Additionally, `NormalizePadFormat` is introduced in order to change `auto_pads` Conv attribute in its explicit `pads` list (ref: https://onnx.ai/onnx/operators/onnx__Conv.html).
1 parent ecb7677 commit d8ad301

File tree

3 files changed

+759
-0
lines changed

3 files changed

+759
-0
lines changed

onnxscript/rewriter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
broadcast_to_matmul,
2828
cast_constant_of_shape,
2929
collapse_slices,
30+
fuse_pad_into_conv,
3031
fuse_relus_clips,
3132
no_op,
3233
pattern,
@@ -49,6 +50,7 @@
4950
*fuse_relus_clips.fuse_relus_clips_rules().rules,
5051
*basic_rules.basic_optimization_rules().rules,
5152
*redundant_scatter_nd.rules.rules,
53+
*fuse_pad_into_conv.fuse_pad_into_conv_rule_set().rules,
5254
)
5355

5456

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Fuses Pad nodes into preceding nodes. Supported fusion patterns:
4+
- Conv ∘ Pad -> Conv
5+
- ConvInteger ∘ Pad -> ConvInteger
6+
7+
To make some rules possible, we implicitly transform `auto_pad` attribute into its explicit list.
8+
"""
9+
10+
from __future__ import annotations
11+
12+
from typing import List, Sequence
13+
14+
import numpy as np
15+
import onnx_ir as ir
16+
17+
from onnxscript.rewriter import pattern as orp
18+
19+
20+
def fill_pads_with_axes(pads: Sequence[int], axes: Sequence[int], rank: int) -> List[int]:
21+
"""Converts the parameters of the ONNX Pad operator into an explicit list of values.
22+
23+
A filled list of pads will be returned following the format:
24+
[x1_begin, x2_begin, ..., x{rank}_begin, x1_end, x2_end, ..., x{rank}_end]
25+
26+
Args:
27+
pads: list of integers indicating the number of padding elements to add at
28+
the beginning and end of each axis.
29+
axes: list of axes that pads apply to.
30+
rank: value to compute the size of the filled list (2 * rank).
31+
32+
Returns:
33+
The filled list of pads.
34+
"""
35+
new_pads = [0] * 2 * rank
36+
N = len(axes)
37+
for start_idx, axis in enumerate(axes):
38+
new_pads[axis] = pads[start_idx]
39+
new_pads[axis + rank] = pads[start_idx + N]
40+
return new_pads
41+
42+
43+
def read_conv_attributes(ir_conv: ir.Node) -> dict[str, Sequence[int] | str]:
44+
# Read attributes
45+
attributes = {}
46+
ir_attributes = ir_conv.attributes
47+
attributes["kernel_shape"] = ir_attributes.get_ints(
48+
"kernel_shape", ir_conv.inputs[1].shape[2:]
49+
)
50+
attributes["strides"] = ir_attributes.get_ints(
51+
"strides", [1] * len(ir_conv.inputs[0].shape[2:])
52+
)
53+
attributes["auto_pad"] = ir_attributes.get_string("auto_pad", "NOTSET")
54+
if "pads" in ir_attributes:
55+
attributes["pads"] = ir_attributes.get_ints("pads")
56+
return attributes
57+
58+
59+
class _FuseConvPadBase(orp.RewriteRuleClassBase):
60+
"""Interface for PadConv nodes fusion."""
61+
62+
def __init__(self, as_function: bool = False):
63+
# Remove nodes is set to False to remove unused nodes after the rewrite, since
64+
# Pad or Conv inputs can come from constant nodes.
65+
# With remove_nodes=False these nodes are removed if these nodes are no longer needed.
66+
super().__init__(remove_nodes=False, as_function=as_function)
67+
68+
def rewrite(
69+
self, op: ir.tape.Tape, x: ir.Value, pad: ir.Value, conv: ir.Value
70+
) -> ir.Value:
71+
conv_node = conv.producer()
72+
73+
# Retrieve the padding and axes
74+
x_rank = len(x.shape)
75+
76+
# Get computed pads in check()
77+
pad_pads = self._pads_list
78+
79+
# Get only spatial pads
80+
new_pads = pad_pads[2:x_rank] + pad_pads[x_rank + 2 :]
81+
82+
# Replace conv pads = new + old
83+
conv_attr = conv_node.attributes.copy()
84+
if "pads" in conv_attr:
85+
new_pads = [x + y for x, y in zip(conv_attr["pads"].as_ints(), new_pads)]
86+
conv_attr.add(ir.AttrInt64s("pads", new_pads))
87+
88+
return op.op(
89+
conv_node.op_type,
90+
inputs=(x, *conv_node.inputs[1:]),
91+
attributes=conv_attr,
92+
domain=conv_node.domain,
93+
name=conv_node.name,
94+
)
95+
96+
def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
97+
"""Condition to check if we need to replace the pattern.
98+
99+
If Pad inputs can be added in 'pads' attribute of the Conv operator.
100+
101+
To validate this, we need to check the following:
102+
1. `Pad<mode>` attribute has 'constant' as value
103+
2. `Pad` operator inputs are constants ('pads', 'constant_value', 'axes')
104+
3. 'constant_value' is equal to 0.0.
105+
4. `Pad` operator is only used for the spatial dimensions (batch dimension and channels
106+
remain unchanged).
107+
108+
If the above are true, then we don't need the reshapes.
109+
110+
Returns:
111+
True if we need to replace the pattern, False otherwise.
112+
"""
113+
del context # Unused
114+
check_result = orp.MatchResult()
115+
pad_node = pad.producer()
116+
if x.shape is None:
117+
return check_result.fail(
118+
f"Input shapes are not defined on {pad_node.name} ({pad_node.op_type})."
119+
)
120+
x_rank = len(x.shape)
121+
122+
# Pad constraints: attributes
123+
if (mode := pad_node.attributes.get("mode", None)) and mode.as_string() != "constant":
124+
return check_result.fail(
125+
f"{pad_node.name} ({pad_node.op_type}) mode must be 'constant'."
126+
)
127+
128+
# Pad constraints: inputs
129+
if (pads := pad_node.inputs[1]).const_value is None:
130+
return check_result.fail(f"{pads.name} is not a constant/initializer.")
131+
if len(pad_node.inputs) > 2 and (constant_value := pad_node.inputs[2]) is not None:
132+
if constant_value.const_value is None:
133+
return check_result.fail(
134+
f"{constant_value.name} is not a constant/initializer."
135+
)
136+
elif constant_value.const_value.numpy().item() != 0:
137+
return check_result.fail(f"{constant_value.name} must be equal to 0.")
138+
if len(pad_node.inputs) > 3 and (axes := pad_node.inputs[3]) is not None:
139+
if axes.const_value is None:
140+
return check_result.fail(f"{axes.name} is not a constant/initializer.")
141+
axes_list = [x if x >= 0 else x_rank + x for x in axes.const_value.numpy()]
142+
else:
143+
axes_list = list(range(x_rank))
144+
145+
# Pad constraints: values
146+
self._pads_list = fill_pads_with_axes(pads.const_value.numpy(), axes_list, x_rank)
147+
if np.any(self._pads_list[:2] + self._pads_list[x_rank : x_rank + 2]):
148+
self._pads_list = None
149+
return check_result.fail(f"{pads.name} must be zero in non-spatial dimensions.")
150+
151+
return check_result
152+
153+
154+
class FuseConvPad(_FuseConvPadBase):
155+
"""Replaces ``Conv(Pad(x))`` with ``Conv(x)``."""
156+
157+
def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
158+
return op.Conv(
159+
op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]),
160+
_allow_other_inputs=True,
161+
_outputs=["conv"],
162+
)
163+
164+
def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.MatchResult:
165+
check_result = super().check(context, x, pad, conv)
166+
if not check_result:
167+
return check_result
168+
169+
# Conv constraints: attributes
170+
conv_node = conv.producer()
171+
if conv_node.attributes.get_string("auto_pad", "NOTSET") != "NOTSET":
172+
return check_result.fail(
173+
f"{conv_node.name} ({conv_node.op_type}) auto_pad must be 'NOTSET'."
174+
)
175+
return check_result
176+
177+
178+
class FuseConvIntegerPad(FuseConvPad):
179+
"""Replaces ``ConvInteger(Pad(x))`` with ``ConvInteger(x)``."""
180+
181+
def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
182+
return op.ConvInteger(
183+
op.Pad(x, _allow_other_inputs=True, _outputs=["pad"]),
184+
_allow_other_inputs=True,
185+
_outputs=["conv"],
186+
)
187+
188+
189+
class _NormalizePadFormatBase(orp.RewriteRuleClassBase):
190+
"""Interface to normalize pad attributes in conv nodes."""
191+
192+
@staticmethod
193+
def compute_pads(
194+
input_shape: Sequence[int],
195+
output_shape: Sequence[int],
196+
attributes: dict[str, Sequence[int] | str],
197+
) -> Sequence[int]:
198+
raise NotImplementedError("Child have to implement this function")
199+
200+
def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value:
201+
conv_node = conv.producer()
202+
203+
# Read spatial dimensions and attributes
204+
input_shape = conv_node.inputs[0].shape[2:]
205+
output_shape = conv_node.outputs[0].shape[2:]
206+
attributes = read_conv_attributes(conv_node)
207+
208+
# Convert auto_pad mode into an explicit list
209+
pads = self.compute_pads(input_shape, output_shape, attributes)
210+
211+
# Replace auto_pad, forcing to the explicit list
212+
conv_attr = conv_node.attributes.copy()
213+
conv_attr.add(ir.AttrString("auto_pad", "NOTSET"))
214+
if any(x != 0 for x in pads):
215+
conv_attr.add(ir.AttrInt64s("pads", pads))
216+
217+
return op.op(
218+
conv_node.op_type,
219+
inputs=conv_node.inputs,
220+
attributes=conv_attr,
221+
domain=conv_node.domain,
222+
name=conv_node.name,
223+
)
224+
225+
def check(self, context, conv: ir.Value, **__) -> orp.MatchResult:
226+
"""Condition to check if we need to replace the pattern.
227+
228+
If it is possible to deduce 'pads'.
229+
230+
To validate this, we need to check the following:
231+
1. `Conv<auto_pad != "NOTSET">` (nothing to do in this case, since 'pads' are
232+
already explicit)
233+
2. it is possible to deduce the input rank when `Conv<auto_pad == "VALID">`
234+
3. When `Conv<auto_pad != "VALID">`:
235+
* spatial input/output shapes are static
236+
* it is possible to infer `kernel_shape` either from the `Conv` operator attribute
237+
or from the kernel input
238+
239+
If the above are true, then we don't need the reshapes.
240+
241+
Returns:
242+
True if we need to replace the pattern, False otherwise.
243+
"""
244+
del context
245+
check_result = orp.MatchResult()
246+
247+
# Conv constraints: attributes
248+
conv_node = conv.producer()
249+
auto_pad = conv_node.attributes.get_string("auto_pad", None)
250+
if auto_pad in {None, "NOTSET"}:
251+
return check_result.fail(
252+
f"{conv_node.name} ({conv_node.op_type}) auto_pad must be different to 'NOTSET'."
253+
)
254+
255+
# Conv constraints: inputs/outputs
256+
input_shape = conv_node.inputs[0].shape
257+
output_shape = conv_node.outputs[0].shape
258+
if input_shape is None or len(input_shape) <= 2:
259+
return check_result.fail(
260+
f"Input shapes are not defined on {conv_node.name} ({conv_node.op_type})."
261+
)
262+
if output_shape is None or len(output_shape) <= 2:
263+
return check_result.fail(
264+
f"Output shapes are not defined on {conv_node.name} ({conv_node.op_type})."
265+
)
266+
267+
# Conv constraints: values
268+
if auto_pad != "VALID":
269+
error_msg = (
270+
"Expected static spatial {} shapes on "
271+
+ conv_node.name
272+
+ f" ({conv_node.op_type})."
273+
)
274+
if not all(isinstance(x, int) for x in input_shape[2:]):
275+
return check_result.fail(error_msg.format("input"))
276+
if not all(isinstance(x, int) for x in output_shape[2:]):
277+
return check_result.fail(error_msg.format("output"))
278+
attributes = read_conv_attributes(conv_node)
279+
if len(attributes["kernel_shape"]) != len(attributes["strides"]):
280+
return check_result.fail(
281+
"strides must have the same length than kernel_shape on "
282+
f"{conv_node.name} ({conv_node.op_type})."
283+
)
284+
return check_result
285+
286+
287+
class NormalizePadFormatConv(_NormalizePadFormatBase):
288+
"""Convert auto_pad attribute into 'NOTSET' in Conv nodes ."""
289+
290+
@staticmethod
291+
def compute_pads(
292+
input_shape: Sequence[int],
293+
output_shape: Sequence[int],
294+
attributes: dict[str, Sequence[int] | str],
295+
) -> Sequence[int]:
296+
# Compute pads, following auto_pad/pads attributes
297+
if attributes["auto_pad"] in {"NOTSET", "VALID"}:
298+
assert len(input_shape) > 0
299+
return attributes.get("pads", [0] * len(input_shape) * 2)
300+
301+
bottom_pads, top_pads = [], []
302+
kernel_shape, strides = attributes["kernel_shape"], attributes["strides"]
303+
assert len(kernel_shape) == len(strides) == len(input_shape) == len(output_shape)
304+
for x, y, k, s in zip(input_shape, output_shape, kernel_shape, strides):
305+
# Compute the output shape and the total padding to apply
306+
total_pads = max(0, (y - 1) * s + k - x)
307+
308+
# Depending of mode, apply the padding to the upper or lower part
309+
pad1 = total_pads // 2
310+
pad2 = total_pads - pad1
311+
if attributes["auto_pad"] == "SAME_UPPER":
312+
bottom_pads.append(pad1)
313+
top_pads.append(pad2)
314+
else:
315+
top_pads.append(pad1)
316+
bottom_pads.append(pad2)
317+
return bottom_pads + top_pads
318+
319+
def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
320+
return op.Conv(x, _allow_other_inputs=True, _outputs=["conv"])
321+
322+
323+
class NormalizePadFormatConvInteger(NormalizePadFormatConv):
324+
"""Convert auto_pad attribute into 'NOTSET' in ConvInteger nodes ."""
325+
326+
def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
327+
return op.ConvInteger(x, _allow_other_inputs=True, _outputs=["conv"])
328+
329+
330+
normalize_pad_format_conv = NormalizePadFormatConv.rule()
331+
normalize_pad_format_conv_integer = NormalizePadFormatConvInteger.rule()
332+
fuse_pad_into_conv = FuseConvPad.rule()
333+
fuse_pad_into_conv_integer = FuseConvIntegerPad.rule()
334+
335+
336+
def fuse_pad_into_conv_rule_set() -> orp.RewriteRuleSet:
337+
"""Returns a set of rewrite rules that fuse Pad nodes into preceding:
338+
- Conv
339+
- ConvInteger
340+
341+
Returns:
342+
RewriteRuleSet
343+
"""
344+
return orp.RewriteRuleSet(
345+
[
346+
normalize_pad_format_conv,
347+
normalize_pad_format_conv_integer,
348+
fuse_pad_into_conv,
349+
fuse_pad_into_conv_integer,
350+
]
351+
)

0 commit comments

Comments
 (0)