@@ -268,24 +268,73 @@ def test_fuse_pad_into_conv_integer(
268268
269269
270270class NormalizePadFormatTest (FusePadConvBaseTest ):
271+ def build_model (
272+ self ,
273+ input_shape : ir .Shape ,
274+ conv_inputs : Sequence [int ],
275+ conv_attributes : Mapping [str , ir .Attr ] | None = None ,
276+ infer_shapes = True ,
277+ ) -> ir .Model :
278+ tape = ir .tape .Tape ()
279+ inputs = []
280+ output_shape = ir .Shape (("?" ,) * len (input_shape ))
281+
282+ # Convert conv_inputs to initializers (if needed)
283+ conv_inputs = list (conv_inputs )
284+ for idx , x in enumerate (conv_inputs ):
285+ if isinstance (x , ir .TensorProtocol ):
286+ conv_inputs [idx ] = tape .initializer (x )
287+ elif isinstance (x , ir .Value ):
288+ inputs .append (x )
289+ elif x is not None :
290+ raise ValueError (f"Unsupported type for pad input ({ x } ): { type (x )} ." )
291+
292+ # Register operations in the tape
293+ x = ir .Input ("X" , shape = input_shape , type = ir .TensorType (ir .DataType .FLOAT ))
294+ y = tape .op (
295+ "Conv" ,
296+ inputs = [x , * conv_inputs ],
297+ attributes = conv_attributes ,
298+ output = ir .Input ("Y" , shape = output_shape , type = x .type ),
299+ )
300+
301+ # Build the model
302+ ir_model = ir .Model (
303+ ir .Graph (
304+ inputs = [x , * inputs ],
305+ outputs = [y ],
306+ nodes = tape .nodes ,
307+ initializers = tape .initializers ,
308+ opset_imports = {"" : 20 },
309+ name = "model" ,
310+ ),
311+ ir_version = 10 ,
312+ )
313+ if len (input_shape ) > 0 and infer_shapes :
314+ onnx_checker .CheckerPass (True )(ir_model )
315+ ir_model = shape_inference .infer_shapes (ir_model )
316+ else :
317+ onnx_checker .CheckerPass (False )(ir_model )
318+ return ir_model
319+
271320 @parameterized .parameterized .expand (
272321 [
273- (strides , kernel_shape , auto_pad )
322+ (dynamic_shape , strides , kernel_shape , auto_pad )
274323 for strides , kernel_shape in [((2 , 3 ), (1 , 4 )), ((2 , 1 ), (5 , 2 ))]
275- for auto_pad in ["SAME_UPPER" , "SAME_LOWER" , "VALID" ]
324+ for dynamic_shape , auto_pad in [
325+ (False , "SAME_UPPER" ),
326+ (False , "SAME_LOWER" ),
327+ (True , "VALID" ),
328+ ]
276329 ]
277330 )
278- def test_normalize_pad_format (self , strides , kernel_shape , auto_pad ):
279- pad_inputs = [
280- ir .tensor ([1 , 1 , 1 , 1 ], name = "pads" ),
281- None ,
282- ir .tensor ([2 , 3 ], name = "axes" ),
283- ]
331+ def test_normalize_pad_format (self , dynamic_shape , strides , kernel_shape , auto_pad ):
332+ input_shape = (
333+ ir .Shape (("N" , "A" , "B" , "C" )) if dynamic_shape else ir .Shape (("N" , 32 , 22 , 27 ))
334+ )
284335 base_model = self .build_model (
285- op_type = "Conv" ,
286- input_shape = ir .Shape (("N" , 32 , 22 , 27 )),
287- weight_shape = (32 , 32 , * kernel_shape ),
288- pad_inputs = pad_inputs ,
336+ input_shape = input_shape ,
337+ conv_inputs = [ir .tensor (self .get_conv_weights ((32 , 32 , * kernel_shape )), name = "W" )],
289338 conv_attributes = {
290339 "strides" : strides ,
291340 "auto_pad" : auto_pad ,
@@ -296,27 +345,51 @@ def test_normalize_pad_format(self, strides, kernel_shape, auto_pad):
296345
297346 # Apply rule
298347 count = fuse_pad_into_conv_rule_set ().apply_to_model (updated_model )
299-
300- # Check that Pad was fused
301- self .assertEqual (count , 2 )
302- self .assertEqual (updated_model .graph .num_nodes (), 1 )
303348 onnx_checker .CheckerPass (True )(updated_model )
304349
350+ # Check conv has changed
351+ self .assertEqual (count , 1 )
352+ self .assertEqual (updated_model .graph [0 ].attributes .get_string ("auto_pad" ), "NOTSET" )
353+
305354 # Check inference
306355 inputs = self .rng .random ((1 , 32 , 22 , 27 ), dtype = "float32" )
307356 testing .assert_numerically_equal (base_model , updated_model , (inputs ,), atol = 0 , rtol = 0 )
308357
309- def test_unsupported_normalize_pad_format (self ):
358+ @parameterized .parameterized .expand (
359+ [
360+ (ir .Shape ([]), False , "Input shapes are not defined" ),
361+ (ir .Shape (("N" , "C" , "A" )), False , "Expected static spatial input shapes" ),
362+ (ir .Shape (("N" , "C" , 32 )), False , "Expected static spatial output shapes" ),
363+ ]
364+ )
365+ def test_unsupported_normalize_pad_format (self , input_shape , infer_shapes , error_msg ):
310366 base_model = self .build_model (
311- op_type = "Conv" ,
312- input_shape = ir .Shape (("N" , 32 , 14 )),
313- weight_shape = (32 , 11 , 4 ),
314- pad_inputs = [ir .tensor ([0 , 0 , 0 , 0 , 0 , 0 ], name = "pads" )],
315- conv_attributes = {"auto_pad" : "VALID" },
367+ input_shape = input_shape ,
368+ conv_inputs = [ir .tensor (np .ones ((32 , 11 , 4 )), name = "W" )],
369+ conv_attributes = {"auto_pad" : "SAME_UPPER" },
370+ infer_shapes = infer_shapes ,
371+ )
372+
373+ # Apply rule and check it was not applied
374+ tracer = orp .MatchingTracer ()
375+ count = normalize_pad_format_conv .apply_to_model (base_model , tracer = tracer )
376+ self .assertEqual (count , 0 )
377+
378+ # Check that the error message is the expected one
379+ tracer_match = tracer .best_matches_map [normalize_pad_format_conv ][0 ]
380+ self .assertEqual (tracer_match .status .value , orp .MatchStatus .CONDITION_FAILED )
381+ self .assertRegex (tracer_match .match_result .reason , error_msg )
382+
383+ def test_unsupported_normalize_pad_format_on_weights (self ):
384+ W = ir .Value (name = "W" , shape = ir .Shape ([]), type = ir .TensorType (ir .DataType .FLOAT ))
385+ base_model = self .build_model (
386+ input_shape = ir .Shape (("N" , 2 , 32 )),
387+ conv_inputs = [W ],
388+ conv_attributes = {"auto_pad" : "SAME_UPPER" },
389+ infer_shapes = False ,
316390 )
317- # Drop convolutional input shape
318- base_model .graph [0 ].outputs [0 ].shape = None
319- onnx_checker .CheckerPass (True )(base_model )
391+ # Set output shape to analyze error due to weights
392+ base_model .graph [0 ].outputs [0 ].shape = ir .Shape (("N" , 10 , 32 ))
320393
321394 # Apply rule and check it was not applied
322395 tracer = orp .MatchingTracer ()
@@ -326,7 +399,7 @@ def test_unsupported_normalize_pad_format(self):
326399 # Check that the error message is the expected one
327400 tracer_match = tracer .best_matches_map [normalize_pad_format_conv ][0 ]
328401 self .assertEqual (tracer_match .status .value , orp .MatchStatus .CONDITION_FAILED )
329- self .assertRegex (tracer_match .match_result .reason , "Input shapes are not defined " )
402+ self .assertRegex (tracer_match .match_result .reason , "same length than kernel_shape " )
330403
331404
332405if __name__ == "__main__" :
0 commit comments