2020 is_channel_last_dim_order ,
2121 is_contiguous_dim_order ,
2222)
23+ from executorch .exir .pass_base import ExportPass
2324
2425from torch .export import export
26+
27+ from torch .fx .passes .infra .pass_manager import PassManager
2528from torch .testing import FileCheck
2629from torch .utils ._pytree import tree_flatten
2730
@@ -99,10 +102,67 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
99102 return t1 * t2
100103
101104
105+ class AmbiguousDimOrderError (RuntimeError ):
106+ pass
107+
108+
109+ def assert_unambiguous_dim_order (gm ):
110+ class ExampleNOPPass (ExportPass ):
111+ """
112+ Does nothing!
113+ """
114+
115+ def call_operator (self , op , args , kwargs , meta ):
116+ return super ().call_operator (
117+ op ,
118+ args ,
119+ kwargs ,
120+ meta ,
121+ )
122+
123+ # This is an example of how one can detect ambiguous dim_order anywhere in the graph.
124+ # You can be surgical and only detect it in the nodes you are interested in or something else.
125+ def detect_ambiguity (gm ):
126+ """
127+ Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats.
128+ """
129+
130+ def get_tensors (node : torch .fx .Node ) -> torch .Tensor :
131+ val = node .meta ["val" ]
132+ if isinstance (val , torch .Tensor ):
133+ return [val ]
134+ elif isinstance (val , (list , tuple )):
135+ return [tensor for tensor in val if isinstance (tensor , torch .Tensor )]
136+ return []
137+
138+ for node in gm .graph .nodes :
139+ if node .op == "call_function" :
140+ for tensor in get_tensors (node ):
141+ # Let's make sure dim_order is not ambiguous, raise otherwise.
142+ # This is raising because we can't do anything about it.
143+ # The right course of follow up action is to ask user to try with a different example input.
144+ try :
145+ _ = tensor .dim_order (
146+ ambiguity_check = [
147+ torch .contiguous_format ,
148+ torch .channels_last ,
149+ ]
150+ )
151+ except Exception :
152+ raise AmbiguousDimOrderError
153+
154+ # any pass or passes, just using MemoryFormatOpsPass as an example
155+ dim_order_pass_manager = PassManager (passes = [ExampleNOPPass ()])
156+ dim_order_pass_manager .add_checks (detect_ambiguity )
157+ dim_order_pass_manager (gm )
158+
159+
102160class MemoryFormatOpsPassTestUtils :
103161 @staticmethod
104162 def memory_format_test_runner (
105- test_class : unittest .TestCase , test_set : MemoryFormatTestSet
163+ test_class : unittest .TestCase ,
164+ test_set : MemoryFormatTestSet ,
165+ check_unambiguous_dim_order : bool = False ,
106166 ):
107167 before = export (
108168 test_set .module , test_set .sample_input , strict = True
@@ -121,6 +181,9 @@ def memory_format_test_runner(
121181 before , compile_config = EdgeCompileConfig (_skip_dim_order = False )
122182 )
123183
184+ if check_unambiguous_dim_order :
185+ assert_unambiguous_dim_order (epm .exported_program ().graph_module )
186+
124187 # check memory format ops, if needed
125188 if test_set .op_level_check :
126189 aten_op_str , edge_op_str = MemoryFormatOps2Str [test_set .op ]
0 commit comments