@@ -414,7 +414,9 @@ def model3(X: ot.FLOAT[1, 1]):
414414
415415class ReshapeReshapeTest (unittest .TestCase ):
416416 @staticmethod
417- def create_model (input_shape , shape1 , shape2 ):
417+ def create_model (
418+ input_shape , shape1 , shape2 , allowzero1 = 0 , allowzero2 = 0 , infer_shape = False
419+ ):
418420 def _convert_shape (shape , name ):
419421 if isinstance (shape , np .ndarray ):
420422 shape = tape .initializer (ir .Tensor (shape , name = name ))
@@ -430,20 +432,43 @@ def _convert_shape(shape, name):
430432 tape = ir .tape .Tape (ir .Graph ([x ], [y ], nodes = [], opset_imports = {"" : 20 }))
431433
432434 # Build the graph.
433- reshape = tape .op ("Reshape" , inputs = [x , _convert_shape (shape1 , "shape_" )])
434- tape .op ("Reshape" , inputs = [reshape , _convert_shape (shape2 , "shape" )], output = y )
435+ reshape = tape .op (
436+ "Reshape" ,
437+ inputs = [x , _convert_shape (shape1 , "shape_" )],
438+ attributes = {"allowzero" : allowzero1 },
439+ )
440+ tape .op (
441+ "Reshape" ,
442+ inputs = [reshape , _convert_shape (shape2 , "shape" )],
443+ attributes = {"allowzero" : allowzero2 },
444+ output = y ,
445+ )
435446 model = ir .Model (tape .graph_like , ir_version = 10 )
447+
448+ # Infer shapes.
449+ if infer_shape :
450+ model = ir .passes .common .ShapeInferencePass ()(model ).model
436451 return model
437452
438453 @parameterized .parameterized .expand (
439454 [
440455 ((3 , 4 , 5 ), [4 , 5 , 3 ], [5 , 4 , 3 ]),
441456 ((3 , 4 , 5 ), [4 , 5 , 3 ], [5 , 4 , 3 ]),
457+ ((3 , 4 , 8 ), [2 , 0 , 3 , - 1 ], [0 , 3 , 2 , 8 ]),
458+ ((3 , 4 , 8 ), [3 , 4 , - 1 ], [- 1 , 12 ], 1 ),
459+ ((3 , 4 , 2 ), [0 , 4 , - 1 ], [12 , - 1 ], 0 , 1 ),
460+ ((3 , 0 , 8 ), [4 , 2 , 0 , 0 ], [3 , 0 ], 1 , 1 ),
442461 ]
443462 )
444- def test_reshape_reshape_rule (self , input_shape , shape1 , shape2 ):
463+ def test_reshape_reshape_rule (
464+ self , input_shape , shape1 , shape2 , allowzero1 = 0 , allowzero2 = 0
465+ ):
445466 model = self .create_model (
446- input_shape , np .array (shape1 , dtype = "int64" ), np .array (shape2 , dtype = "int64" )
467+ input_shape ,
468+ np .array (shape1 , dtype = "int64" ),
469+ np .array (shape2 , dtype = "int64" ),
470+ allowzero1 = allowzero1 ,
471+ allowzero2 = allowzero2 ,
447472 )
448473 updated_model = clone_model (model )
449474
@@ -456,19 +481,64 @@ def test_reshape_reshape_rule(self, input_shape, shape1, shape2):
456481 inputs = np .random .default_rng (10 ).random (input_shape , dtype = "float32" )
457482 testing .assert_numerically_equal (model , updated_model , (inputs ,), atol = 0 , rtol = 0 )
458483
484+ @parameterized .parameterized .expand ([([3 , 2 , 3 , 3 , 3 ], 1 ), ([0 , - 1 , 3 , 2 ], 0 )])
485+ def test_reshape_dynamic_reshape_rule (self , shape1 , allowzero1 = 0 ):
486+ input_shape = (3 , 6 , 9 )
487+ shape1 = np .array (shape1 , dtype = "int64" )
488+ # Build the model with unknown shape1.
489+ model = self .create_model (
490+ input_shape ,
491+ (shape1 .size ,),
492+ np .array ((1 , 6 , 27 ), dtype = "int64" ),
493+ allowzero1 = allowzero1 ,
494+ )
495+ updated_model = clone_model (model )
496+
497+ # check rewrite approach.
498+ count = basic_rules .reshape_reshape_rule .apply_to_model (updated_model )
499+ self .assertEqual (count , 1 )
500+ self .assertEqual (["Reshape" ], [n .op_type for n in updated_model .graph ])
501+
502+ # Check inference.
503+ feeds = {
504+ "X" : np .random .default_rng (2 ).random (input_shape , dtype = "float32" ),
505+ "shape_" : shape1 ,
506+ }
507+ testing .assert_numerically_equal (model , updated_model , feeds , atol = 0 , rtol = 0 )
508+
509+ @parameterized .parameterized .expand (
510+ [((3 , 6 , 9 ), [0 , 3 , 2 , - 1 ]), ((0 , 6 , 2 ), [0 , 0 , 3 ], 1 )]
511+ )
512+ def test_reshape_reshape_dynamic_rule (self , input_shape , shape2 , allowzero2 = 0 ):
513+ # Note that shape inference is required for this test to be valid.
514+ shape2 = np .array (shape2 , dtype = "int64" )
515+ model = self .create_model (
516+ input_shape ,
517+ np .array ((3 , 2 , - 1 ), dtype = "int64" ),
518+ shape2 ,
519+ allowzero2 = allowzero2 ,
520+ infer_shape = True ,
521+ )
522+ updated_model = clone_model (model )
523+
524+ # check rewrite approach.
525+ count = basic_rules .reshape_reshape_rule .apply_to_model (updated_model )
526+ self .assertEqual (count , 1 )
527+ self .assertEqual (["Reshape" ], [n .op_type for n in updated_model .graph ])
528+
529+ # Check inference.
530+ inputs = np .random .default_rng (7 ).random (input_shape , dtype = "float32" )
531+ testing .assert_numerically_equal (model , updated_model , (inputs ,), atol = 0 , rtol = 0 )
532+
459533 @parameterized .parameterized .expand (
460534 [
461- ((2 ,), np .array ([1 , 6 ], dtype = "int64" ), "ignored is not a constant" ),
462- (np .array ([1 , 6 ], dtype = "int64" ), (3 ,), "is not a constant" ),
463- (
464- np .array ([1 , 6 ], dtype = "int64" ),
465- np .array ([0 , 6 ], dtype = "int64" ),
466- "non-positive values" ,
467- ),
535+ ((3 ,), "is not a constant" ),
536+ (np .array ([0 , - 1 ], dtype = "int64" ), "both 0 and -1 dimensions" ),
537+ (np .array ([0 , 0 , 3 ], dtype = "int64" ), "more than one 0 dimension" ),
468538 ]
469539 )
470- def test_unsupported_reshape_reshape (self , shape1 , shape2 , error_msg ):
471- model = self .create_model ((1 , 2 , 3 ), shape1 , shape2 )
540+ def test_unsupported_reshape_reshape (self , shape2 , error_msg ):
541+ model = self .create_model ((1 , 2 , 3 ), np . array ([ 1 , 6 ], dtype = "int64" ) , shape2 )
472542
473543 # Check rewrite approach.
474544 tracer = MatchingTracer ()
0 commit comments