1818
1919
2020def fill_pads_with_axes (pads : Sequence [int ], axes : Sequence [int ], rank : int ) -> List [int ]:
21+ """Converts the parameters of the ONNX Pad operator into an explicit list of values.
22+
23+ A filled list of pads will be returned following the format:
24+ [x1_begin, x2_begin, ..., x{rank}_begin, x1_end, x2_end, ..., x{rank}_end]
25+
26+ Args:
27+ pads: list of integers indicating the number of padding elements to add at
28+ the beginning and end of each axis.
29+ axes: list of axes that pads apply to.
30+ rank: value to compute the size of the filled list (2 * rank).
31+
32+ Returns:
33+ The filled list of pads.
34+ """
2135 new_pads = [0 ] * 2 * rank
2236 N = len (axes )
2337 for start_idx , axis in enumerate (axes ):
@@ -42,11 +56,13 @@ def read_conv_attributes(ir_conv: ir.Node) -> dict[str, Sequence[int] | str]:
4256 return attributes
4357
4458
45- class _FusePadConvBase (orp .RewriteRuleClassBase ):
59+ class _FuseConvPadBase (orp .RewriteRuleClassBase ):
4660 """Interface for PadConv nodes fusion."""
4761
4862 def __init__ (self , as_function : bool = False ):
49- # Remove nodes is set to False to remove unused nodes after the rewrite.
63+ # Remove nodes is set to False to remove unused nodes after the rewrite, since
64+ # Pad or Conv inputs can come from constant nodes.
65+ # With remove_nodes=False these nodes are removed if these nodes are no longer needed.
5066 super ().__init__ (remove_nodes = False , as_function = as_function )
5167
5268 def rewrite (
@@ -84,14 +100,32 @@ def rewrite(
84100 )
85101
86102 def check (self , context , x : ir .Value , pad : ir .Value , conv : ir .Value ) -> orp .MatchResult :
103+ """Condition to check if we need to replace the pattern.
104+
105+ If Pad inputs can be added in 'pads' attribute of the Conv operator.
106+
107+ To validate this, we need to check the following:
108+ 1. `Pad<mode>` attribute has 'constant' as value
109+ 2. `Pad` operator inputs are constants ('pads', 'constant_value', 'axes')
110+ 3. 'constant_value' is equal to 0.0.
111+ 4. `Pad` operator is only used for the spatial dimensions (batch dimension and channels
112+ remain unchanged).
113+
114+ If the above are true, then we don't need the reshapes.
115+
116+ Returns:
117+ True if we need to replace the pattern, False otherwise.
118+ """
87119 del context # Unused
88120 check_result = orp .MatchResult ()
89121 pad_node = pad .producer ()
90122 x_rank = len (x .shape )
91123
92124 # Pad constraints: attributes
93125 if (mode := pad_node .attributes .get ("mode" , None )) and mode .as_string () != "constant" :
94- return check_result .fail (f"{ pad_node .name } mode must be 'constant'." )
126+ return check_result .fail (
127+ f"{ pad_node .name } ({ pad_node .op_type } ) mode must be 'constant'."
128+ )
95129
96130 # Pad constraints: inputs
97131 if (pads := pad_node .inputs [1 ]).const_value is None :
@@ -118,8 +152,8 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
118152 return check_result
119153
120154
121- class FusePadConv ( _FusePadConvBase ):
122- """Replaces ``Pad( Conv(x))`` with ``Conv(x)``."""
155+ class FuseConvPad ( _FuseConvPadBase ):
156+ """Replaces ``Conv(Pad (x))`` with ``Conv(x)``."""
123157
124158 def pattern (self , op : ir .tape .Tape , x : ir .Value ) -> ir .Value :
125159 return op .Conv (
@@ -138,12 +172,14 @@ def check(self, context, x: ir.Value, pad: ir.Value, conv: ir.Value) -> orp.Matc
138172 if (
139173 apad := conv_node .attributes .get ("auto_pad" , None )
140174 ) and apad .as_string () != "NOTSET" :
141- return check_result .fail (f"{ conv_node .name } auto_pad must be 'NOTSET'." )
175+ return check_result .fail (
176+ f"{ conv_node .name } ({ conv_node .op_type } ) auto_pad must be 'NOTSET'."
177+ )
142178 return check_result
143179
144180
145- class FusePadConvInteger ( FusePadConv ):
146- """Replaces ``Pad( ConvInteger(x))`` with ``ConvInteger(x)``."""
181+ class FuseConvIntegerPad ( FuseConvPad ):
182+ """Replaces ``ConvInteger(Pad (x))`` with ``ConvInteger(x)``."""
147183
148184 def pattern (self , op : ir .tape .Tape , x : ir .Value ) -> ir .Value :
149185 return op .ConvInteger (
@@ -190,36 +226,63 @@ def rewrite(self, op: ir.tape.Tape, conv: ir.Value, **__) -> ir.Value:
190226 )
191227
192228 def check (self , context , conv : ir .Value , ** __ ) -> orp .MatchResult :
229+ """Condition to check if we need to replace the pattern.
230+
231+ If it is possible to deduce 'pads'.
232+
233+ To validate this, we need to check the following:
234+ 1. `Conv<auto_pad != "NOTSET">` (nothing to do in this case, since 'pads' are
235+ already explicit)
236+ 2. it is possible to deduce the input rank when `Conv<auto_pad == "VALID">`
237+ 3. When `Conv<auto_pad != "VALID">`:
238+ * spatial input/output shapes are static
239+ * it is possible to infer `kernel_shape` either from the `Conv` operator attribute
240+ or from the kernel input
241+
242+ If the above are true, then we don't need the reshapes.
243+
244+ Returns:
245+ True if we need to replace the pattern, False otherwise.
246+ """
193247 del context
194248 check_result = orp .MatchResult ()
195249
196250 # Conv constraints: attributes
197251 conv_node = conv .producer ()
198252 auto_pad = conv_node .attributes .get_string ("auto_pad" , None )
199- if auto_pad in [ None , "NOTSET" ] :
253+ if auto_pad in { None , "NOTSET" } :
200254 return check_result .fail (
201- f"{ conv_node .name } auto_pad must be different to 'NOTSET'."
255+ f"{ conv_node .name } ( { conv_node . op_type } ) auto_pad must be different to 'NOTSET'."
202256 )
203257
204258 # Conv constraints: inputs/outputs
205259 input_shape = conv_node .inputs [0 ].shape
206260 output_shape = conv_node .outputs [0 ].shape
207261 if len (input_shape ) <= 2 :
208- return check_result .fail (f"Input shapes are not defined on { conv_node .name } ." )
262+ return check_result .fail (
263+ f"Input shapes are not defined on { conv_node .name } ({ conv_node .op_type } )."
264+ )
209265 if len (output_shape ) <= 2 :
210- return check_result .fail (f"Output shapes are not defined on { conv_node .name } ." )
266+ return check_result .fail (
267+ f"Output shapes are not defined on { conv_node .name } ({ conv_node .op_type } )."
268+ )
211269
212270 # Conv constraints: values
213271 if auto_pad != "VALID" :
214- error_msg = "Expected static spatial {} shapes on " + conv_node .name + "."
272+ error_msg = (
273+ "Expected static spatial {} shapes on "
274+ + conv_node .name
275+ + f" ({ conv_node .op_type } )."
276+ )
215277 if not all (isinstance (x , int ) for x in input_shape [2 :]):
216278 return check_result .fail (error_msg .format ("input" ))
217279 if not all (isinstance (x , int ) for x in output_shape [2 :]):
218280 return check_result .fail (error_msg .format ("output" ))
219281 attributes = read_conv_attributes (conv_node )
220282 if len (attributes ["kernel_shape" ]) != len (attributes ["strides" ]):
221283 return check_result .fail (
222- f"strides must have the same length than kernel_shape on { conv_node .name } ."
284+ "strides must have the same length than kernel_shape on "
285+ f"{ conv_node .name } ({ conv_node .op_type } )."
223286 )
224287 return check_result
225288
@@ -234,7 +297,7 @@ def compute_pads(
234297 attributes : dict [str , Sequence [int ] | str ],
235298 ) -> Sequence [int ]:
236299 # Compute pads, following auto_pad/pads attributes
237- if attributes ["auto_pad" ] in [ "NOTSET" , "VALID" ] :
300+ if attributes ["auto_pad" ] in { "NOTSET" , "VALID" } :
238301 assert len (input_shape ) > 0
239302 return attributes .get ("pads" , [0 ] * len (input_shape ) * 2 )
240303
@@ -269,8 +332,8 @@ def pattern(self, op: ir.tape.Tape, x: ir.Value) -> ir.Value:
269332
270333normalize_pad_format_conv = NormalizePadFormatConv .rule ()
271334normalize_pad_format_conv_integer = NormalizePadFormatConvInteger .rule ()
272- fuse_pad_into_conv = FusePadConv .rule ()
273- fuse_pad_into_conv_integer = FusePadConvInteger .rule ()
335+ fuse_pad_into_conv = FuseConvPad .rule ()
336+ fuse_pad_into_conv_integer = FuseConvIntegerPad .rule ()
274337
275338
276339def fuse_pad_into_conv_rule_set () -> orp .RewriteRuleSet :
0 commit comments