66import numpy as np
77import onnx_ir as ir
88import parameterized
9- from onnx_ir .passes .common import onnx_checker
9+ from onnx_ir .passes .common import onnx_checker , shape_inference
1010
1111from onnxscript .rewriter import pattern as orp
1212from onnxscript .rewriter import testing
1313from onnxscript .rewriter .fuse_pad_into_conv import (
1414 fuse_pad_into_conv ,
1515 fuse_pad_into_conv_rule_set ,
16+ normalize_pad_format_conv ,
1617)
1718
1819
@@ -83,22 +84,24 @@ def build_model(
8384 ir_version = 9 ,
8485 )
8586 onnx_checker .CheckerPass (True )(ir_model )
87+ ir_model = shape_inference .infer_shapes (ir_model )
8688 return ir_model
8789
8890
8991class FusePadConvTest (FusePadConvBaseTest ):
9092 @parameterized .parameterized .expand (
9193 [
92- (pad_pads , const_value , axes , conv_pads )
93- for pad_pads , axes , conv_pads in [
94- ([0 , 0 , 2 , 2 , 0 , 0 , 2 , 2 ], None , None ),
95- ([0 , 2 , 2 , 0 , 2 , 2 ], ir .tensor ([1 , - 2 , - 1 ], name = "axes" ), [2 , 0 , 2 , 0 ]),
96- ([1 , 1 , 1 , 1 ], ir .tensor ([- 2 , 3 ], name = "axes" ), [0 , 1 , 0 , 1 ]),
94+ (pad_pads , const_value , axes , conv_pads , conv_auto_pad )
95+ for pad_pads , axes , conv_pads , conv_auto_pad in [
96+ ([0 , 0 , 2 , 2 , 0 , 0 , 2 , 2 ], None , None , None ),
97+ ([0 , 2 , 2 , 0 , 2 , 2 ], ir .tensor ([1 , - 2 , - 1 ], name = "axes" ), [2 , 0 , 2 , 0 ], None ),
98+ ([1 , 1 , 1 , 1 ], ir .tensor ([- 2 , 3 ], name = "axes" ), [0 , 1 , 0 , 1 ], None ),
99+ ([1 , 3 , 1 , 3 ], ir .tensor ([3 , 2 ], name = "axes" ), None , "VALID" ),
97100 ]
98101 for const_value in [None , 0.0 ]
99102 ]
100103 )
101- def test_fuse_pad_into_conv (self , pad_pads , const_value , axes , conv_pads ):
104+ def test_fuse_pad_into_conv (self , pad_pads , const_value , axes , conv_pads , conv_auto_pad ):
102105 pad_inputs = [ir .tensor (pad_pads , name = "pads" )]
103106 if const_value is not None or axes is not None :
104107 pad_inputs .append (const_value )
@@ -109,15 +112,15 @@ def test_fuse_pad_into_conv(self, pad_pads, const_value, axes, conv_pads):
109112 input_shape = ir .Shape (("N" , 32 , 14 , 16 )),
110113 weight_shape = (10 , 32 , 3 , 3 ),
111114 pad_inputs = pad_inputs ,
112- conv_attributes = {"pads" : conv_pads },
115+ conv_attributes = {"pads" : conv_pads , "auto_pad" : conv_auto_pad },
113116 )
114117 updated_model = _clone_model (base_model )
115118
116119 # Apply rule
117120 count = fuse_pad_into_conv_rule_set ().apply_to_model (updated_model )
118121
119122 # Check that Pad was fused
120- self .assertEqual (count , 1 )
123+ self .assertEqual (count , 1 if conv_auto_pad is None else 2 )
121124 self .assertEqual (updated_model .graph .num_nodes (), 1 )
122125 onnx_checker .CheckerPass (True )(updated_model )
123126
@@ -223,16 +226,19 @@ def get_conv_weights(self, shape: typing.Sequence[int], tape: ir.tape.Tape = Non
223226
224227 @parameterized .parameterized .expand (
225228 [
226- (pad_pads , const_value , axes , conv_pads )
227- for pad_pads , axes , conv_pads in [
228- ([0 , 0 , 3 , 2 , 0 , 0 , 1 , 4 ], None , [1 , 1 , 1 , 1 ]),
229- ([2 , 2 , 0 , 2 , 2 , 0 ], ir .tensor ([- 2 , - 1 , 1 ], name = "axes" ), None ),
230- ([1 , 2 , 2 , 1 ], ir .tensor ([- 1 , 2 ], name = "axes" ), [0 , 1 , 0 , 1 ]),
229+ (pad_pads , const_value , axes , conv_pads , conv_auto_pad )
230+ for pad_pads , axes , conv_pads , conv_auto_pad in [
231+ ([0 , 0 , 3 , 2 , 0 , 0 , 1 , 4 ], None , [1 , 1 , 1 , 1 ], None ),
232+ ([2 , 2 , 0 , 2 , 2 , 0 ], ir .tensor ([- 2 , - 1 , 1 ], name = "axes" ), None , None ),
233+ ([1 , 2 , 2 , 1 ], ir .tensor ([- 1 , 2 ], name = "axes" ), [0 , 1 , 0 , 1 ], None ),
234+ ([3 , 3 ], ir .tensor ([2 ], name = "axes" ), None , "SAME_UPPER" ),
231235 ]
232236 for const_value in [None , ir .tensor (np .array ([0 ], "uint8" ), name = "const_value" )]
233237 ]
234238 )
235- def test_fuse_pad_into_conv_integer (self , pad_pads , const_value , axes , conv_pads ):
239+ def test_fuse_pad_into_conv_integer (
240+ self , pad_pads , const_value , axes , conv_pads , conv_auto_pad
241+ ):
236242 pad_inputs = [ir .tensor (pad_pads , name = "pads" )]
237243 if const_value is not None or axes is not None :
238244 pad_inputs .append (const_value )
@@ -243,15 +249,15 @@ def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads
243249 input_shape = ir .Shape (("N" , 24 , 19 , 23 )),
244250 weight_shape = (8 , 24 , 3 , 3 ),
245251 pad_inputs = pad_inputs ,
246- conv_attributes = {"pads" : conv_pads },
252+ conv_attributes = {"pads" : conv_pads , "auto_pad" : conv_auto_pad },
247253 )
248254 updated_model = _clone_model (base_model )
249255
250256 # Apply rule
251257 count = fuse_pad_into_conv_rule_set ().apply_to_model (updated_model )
252258
253259 # Check that Pad was fused
254- self .assertEqual (count , 1 )
260+ self .assertEqual (count , 1 if conv_auto_pad is None else 2 )
255261 self .assertEqual (updated_model .graph .num_nodes (), 1 )
256262 onnx_checker .CheckerPass (True )(updated_model )
257263
@@ -260,5 +266,67 @@ def test_fuse_pad_into_conv_integer(self, pad_pads, const_value, axes, conv_pads
260266 testing .assert_numerically_equal (base_model , updated_model , (inputs ,), atol = 0 , rtol = 0 )
261267
262268
class NormalizePadFormatTest(FusePadConvBaseTest):
    """Tests for the rule that normalizes Conv ``auto_pad`` before Pad+Conv fusion."""

    @parameterized.parameterized.expand(
        [
            (stride_pair, kshape, mode)
            for stride_pair, kshape in (((2, 3), (1, 4)), ((2, 1), (5, 2)))
            for mode in ("SAME_UPPER", "SAME_LOWER", "VALID")
        ]
    )
    def test_normalize_pad_format(self, strides, kernel_shape, auto_pad):
        """A Pad feeding a Conv with ``auto_pad`` collapses into one Conv node."""
        pad_operands = [
            ir.tensor([1, 1, 1, 1], name="pads"),
            None,
            ir.tensor([2, 3], name="axes"),
        ]
        original_model = self.build_model(
            op_type="Conv",
            input_shape=ir.Shape(("N", 32, 22, 27)),
            weight_shape=(32, 32, *kernel_shape),
            pad_inputs=pad_operands,
            conv_attributes={
                "strides": strides,
                "auto_pad": auto_pad,
                "kernel_shape": kernel_shape,
            },
        )
        rewritten_model = _clone_model(original_model)

        # Run the complete rule set on the clone.
        applied = fuse_pad_into_conv_rule_set().apply_to_model(rewritten_model)

        # Both the normalization rule and the fusion rule must fire,
        # leaving a single Conv node behind.
        self.assertEqual(applied, 2)
        self.assertEqual(rewritten_model.graph.num_nodes(), 1)
        onnx_checker.CheckerPass(True)(rewritten_model)

        # The rewrite must be numerically exact (zero tolerance).
        x = self.rng.random((1, 32, 22, 27), dtype="float32")
        testing.assert_numerically_equal(original_model, rewritten_model, (x,), atol=0, rtol=0)

    def test_unsupported_normalize_pad_format(self):
        """Normalization is rejected when the Conv input shape is unknown."""
        model = self.build_model(
            op_type="Conv",
            input_shape=ir.Shape(("N", 32, 14)),
            weight_shape=(32, 11, 4),
            pad_inputs=[ir.tensor([0, 0, 0, 0, 0, 0], name="pads")],
            conv_attributes={"auto_pad": "VALID"},
        )
        # Erase the shape of the Pad output so the rule's shape precondition fails.
        model.graph[0].outputs[0].shape = None
        onnx_checker.CheckerPass(True)(model)

        # Apply only the normalization rule; it must not rewrite anything.
        tracer = orp.MatchingTracer()
        self.assertEqual(normalize_pad_format_conv.apply_to_model(model, tracer=tracer), 0)

        # The failure is reported as a condition failure with the expected reason.
        best = tracer.best_matches_map[normalize_pad_format_conv][0]
        self.assertEqual(best.status.value, orp.MatchStatus.CONDITION_FAILED)
        self.assertRegex(best.match_result.reason, "Input shapes are not defined")
330+
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()
0 commit comments