2020 is_channel_last_dim_order ,
2121 is_contiguous_dim_order ,
2222)
23+ from executorch .exir .pass_base import ExportPass
24+
25+ from exir .passes .memory_format_ops_pass import MemoryFormatOpsPass
2326
2427from torch .export import export
28+
29+ from torch .fx .passes .infra .pass_manager import PassManager
2530from torch .testing import FileCheck
2631from torch .utils ._pytree import tree_flatten
2732
@@ -99,6 +104,50 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
99104 return t1 * t2
100105
101106
107+ def assert_unambiguous_dim_order (gm ):
108+ # This is just an example, you can add your own pass or passes.
109+ class ExampleNOPPass (ExportPass ):
110+ """
111+ Does nothing!
112+ """
113+
114+ def call_operator (self , op , args , kwargs , meta ):
115+ return super ().call_operator (
116+ op ,
117+ args ,
118+ kwargs ,
119+ meta ,
120+ )
121+
122+ # This is an example of how one can detect ambiguous dim_order anywhere in the graph.
123+ # You can be surgical and only detect it in the nodes you are interested in or something else.
124+ def detect_ambiguity (gm ):
125+ """
126+ Check every node's output tensor dim_order and raise if it is ambiguous for a list of formats.
127+ """
128+ for node in gm .graph .nodes :
129+ if node .op == "call_function" :
130+ tensor = node .meta ["val" ]
131+ # Let's make sure dim_order is not ambiguous, raise otherwise.
132+ # This is raising because we can't do anything about it.
133+ # The right course of follow up action is to ask user to try with a different example input.
134+ print (f"node: { node } , shape: { tensor .shape } , " , end = "" )
135+
136+ try :
137+ dim_order = tensor .dim_order (
138+ ambiguity_check = [torch .contiguous_format , torch .channels_last ]
139+ )
140+ print (f"dim_order: { dim_order } " )
141+ except Exception as e :
142+ print ("" )
143+ raise RuntimeError (e )
144+
145+ # any pass or passes, just using MemoryFormatOpsPass as an example
146+ dim_order_pass_manager = PassManager (passes = [ExampleNOPPass ()])
147+ dim_order_pass_manager .add_checks (detect_ambiguity )
148+ dim_order_pass_manager (gm )
149+
150+
102151class MemoryFormatOpsPassTestUtils :
103152 @staticmethod
104153 def memory_format_test_runner (
@@ -121,6 +170,9 @@ def memory_format_test_runner(
121170 before , compile_config = EdgeCompileConfig (_skip_dim_order = False )
122171 )
123172
173+ # Just as an example
174+ assert_unambiguous_dim_order (epm .exported_program ().graph_module )
175+
124176 # check memory format ops, if needed
125177 if test_set .op_level_check :
126178 aten_op_str , edge_op_str = MemoryFormatOps2Str [test_set .op ]
@@ -153,6 +205,8 @@ def memory_format_test_runner(
153205
154206 # check EdgeOp and the new BackendOp should behave the same in the runtime
155207 executorch_prog = epm .to_executorch ()
208+ if test_set ._load_for_executorch_from_buffer is None :
209+ return
156210
157211 executorch_module = test_set ._load_for_executorch_from_buffer (
158212 executorch_prog .buffer
0 commit comments