@@ -551,5 +551,65 @@ def test_unsupported_reshape_reshape(self, shape2, error_msg):
551551 self .assertRegex (tracer_match .match_result .reason , error_msg )
552552
553553
554+ class Flatten2ReshapeTest (unittest .TestCase ):
555+ @staticmethod
556+ def create_model (input_shape , axis = 1 ):
557+ x = ir .Input ("X" , ir .Shape (input_shape ), ir .TensorType (ir .DataType .FLOAT ))
558+ y = ir .Input ("Y" , type = ir .TensorType (ir .DataType .FLOAT ))
559+ tape = ir .tape .Tape (ir .Graph ([x ], [y ], nodes = [], opset_imports = {"" : 20 }))
560+
561+ # Build the graph.
562+ tape .op ("Flatten" , inputs = [x ], attributes = {"axis" : axis }, output = y )
563+ model = ir .Model (tape .graph_like , ir_version = 10 )
564+ return model
565+
566+ @parameterized .parameterized .expand (list (range (- 5 , 6 )))
567+ def test_flatten_to_reshape_rule (self , axis ):
568+ input_shape = (1 , 4 , 8 , 7 , 5 )
569+ model = self .create_model (input_shape = input_shape , axis = axis )
570+ updated_model = clone_model (model )
571+
572+ # check rewrite approach.
573+ count = basic_rules .flatten_to_reshape_rule .apply_to_model (updated_model )
574+ self .assertEqual (count , 1 )
575+ self .assertEqual (["Reshape" ], [n .op_type for n in updated_model .graph ])
576+
577+ # Check inference.
578+ inputs = np .random .default_rng (13 ).random (input_shape , dtype = "float32" )
579+ testing .assert_numerically_equal (model , updated_model , (inputs ,), atol = 0 , rtol = 0 )
580+
581+ @parameterized .parameterized .expand (list (range (- 4 , 5 )))
582+ def test_flatten_to_reshape_dynamic_input (self , axis ):
583+ model = self .create_model (input_shape = ("N" , "C1" , "C2" , "C3" ), axis = axis )
584+ # Rule is supported in all cases if the output shape is known for non-special cases.
585+ input_shape = (1 , 2 , 3 , 4 )
586+ if axis not in {- 3 , 0 , 1 , 4 }:
587+ out_shape = ir .Shape ((np .prod (input_shape [:axis ]), np .prod (input_shape [axis :])))
588+ model .graph .outputs [0 ].shape = out_shape
589+ updated_model = clone_model (model )
590+
591+ # check rewrite approach.
592+ count = basic_rules .flatten_to_reshape_rule .apply_to_model (updated_model )
593+ self .assertEqual (count , 1 )
594+ self .assertEqual (["Reshape" ], [n .op_type for n in updated_model .graph ])
595+
596+ # Check inference.
597+ inputs = np .random .default_rng (17 ).random (input_shape , dtype = "float32" )
598+ testing .assert_numerically_equal (model , updated_model , (inputs ,), atol = 0 , rtol = 0 )
599+
600+ def test_unsupported_flatten_to_reshape (self ):
601+ model = self .create_model (input_shape = ("N" , "C1" , "C2" ), axis = 2 )
602+
603+ # Check rewrite approach.
604+ tracer = MatchingTracer ()
605+ count = basic_rules .flatten_to_reshape_rule .apply_to_model (model , tracer = tracer )
606+ self .assertEqual (count , 0 )
607+
608+ # Check that the error message is the expected one
609+ tracer_match = tracer .best_matches_map [basic_rules .flatten_to_reshape_rule ][0 ]
610+ self .assertEqual (tracer_match .status .value , orp .MatchStatus .CONDITION_FAILED )
611+ self .assertRegex (tracer_match .match_result .reason , "Impossible to compute new shape" )
612+
613+
554614if __name__ == "__main__" :
555615 unittest .main (verbosity = 2 )
0 commit comments