     is_channel_last_dim_order,
     is_contiguous_dim_order,
 )
+from executorch.exir.pass_base import ExportPass
 
 from torch.export import export
+
+from torch.fx.passes.infra.pass_manager import PassManager
 from torch.testing import FileCheck
 from torch.utils._pytree import tree_flatten
 
@@ -99,10 +102,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return t1 * t2
 
 
+class AmbiguousDimOrderError(RuntimeError):
+    pass
+
+
+def assert_unambiguous_dim_order(gm):
+    class ExampleNOPPass(ExportPass):
+        """
+        Does nothing!
+        """
+
+        def call_operator(self, op, args, kwargs, meta):
+            return super().call_operator(
+                op,
+                args,
+                kwargs,
+                meta,
+            )
+
+    # This is an example of how one can detect an ambiguous dim_order anywhere in the graph.
+    # You can also be surgical and only check the nodes you are interested in.
+    def detect_ambiguity(gm):
+        """
+        Check every node's output tensor dim_order and raise if it is ambiguous for the given list of memory formats.
+        """
+
+        def get_tensors(node: torch.fx.Node) -> list[torch.Tensor]:
+            val = node.meta["val"]
+            if isinstance(val, torch.Tensor):
+                return [val]
+            elif isinstance(val, (list, tuple)):
+                return [tensor for tensor in val if isinstance(tensor, torch.Tensor)]
+            return []
+
+        for node in gm.graph.nodes:
+            if node.op == "call_function":
+                for tensor in get_tensors(node):
+                    # Make sure dim_order is not ambiguous, and raise otherwise.
+                    # We raise because there is nothing we can do about it here;
+                    # the right follow-up is to ask the user to try a different example input.
+                    try:
+                        _ = tensor.dim_order(
+                            ambiguity_check=[
+                                torch.contiguous_format,
+                                torch.channels_last,
+                            ]
+                        )
+                    except Exception:
+                        raise AmbiguousDimOrderError
+
+    # Any pass or passes work here; ExampleNOPPass is just used as an example.
+    dim_order_pass_manager = PassManager(passes=[ExampleNOPPass()])
+    dim_order_pass_manager.add_checks(detect_ambiguity)
+    dim_order_pass_manager(gm)
+
+
 class MemoryFormatOpsPassTestUtils:
     @staticmethod
     def memory_format_test_runner(
-        test_class: unittest.TestCase, test_set: MemoryFormatTestSet
+        test_class: unittest.TestCase,
+        test_set: MemoryFormatTestSet,
+        check_unambiguous_dim_order: bool = False,
     ):
         before = export(
             test_set.module, test_set.sample_input, strict=True
@@ -121,6 +181,9 @@ def memory_format_test_runner(
                 before, compile_config=EdgeCompileConfig(_skip_dim_order=False)
             )
 
+        if check_unambiguous_dim_order:
+            assert_unambiguous_dim_order(epm.exported_program().graph_module)
+
         # check memory format ops, if needed
         if test_set.op_level_check:
             aten_op_str, edge_op_str = MemoryFormatOps2Str[test_set.op]
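
For reference, here is a minimal sketch (not part of the diff above) of how the new helper could be exercised end to end. The toy module, its name, and the example input are hypothetical; the imports mirror what this test file already uses.

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from torch.export import export


class _ToChannelsLast(torch.nn.Module):
    # Hypothetical toy module, only for illustration.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(memory_format=torch.channels_last) * 2


# A 4D input with all dims > 1 keeps contiguous vs. channels_last strides distinguishable.
sample_input = (torch.randn(2, 3, 4, 5),)
ep = export(_ToChannelsLast().eval(), sample_input, strict=True)
epm = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))

# Raises AmbiguousDimOrderError if any call_function node's output dim_order
# cannot be disambiguated between contiguous and channels_last for this input.
assert_unambiguous_dim_order(epm.exported_program().graph_module)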