@@ -9,13 +9,23 @@
 import logging
 
 import torch
-from executorch.backends.arm._passes import AnnotateOutputDimOrderPass
+from executorch.backends.arm._passes.annotate_decomposed_matmul import (
+    AnnotateDecomposedMatmulPass,
+)
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
-    get_output_dim_orders,
     is_param_node,
 )
+from executorch.backends.arm.constants import (
+    HWCM_ORDER,
+    NCHW_ORDER,
+    NHWC_INVERSE_ORDER,
+    NHWC_ORDER,
+    NNCHW_ORDER,
+    NNHWC_INVERSE_ORDER,
+    NNHWC_ORDER,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -38,12 +48,6 @@ class ToTosaMemoryFormatPass(ExportPass):
     The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
     """
 
-    NHWC_order = (0, 2, 3, 1)
-    NHWC_inverse_order = (0, 3, 1, 2)
-    HWCM_order = (2, 3, 0, 1)
-    NNHWC_order = (0, 1, 3, 4, 2)
-    NNHWC_inverse_order = (0, 1, 4, 2, 3)
-
     def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()
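The deleted class attributes map one-to-one onto the module-level constants now imported from `executorch.backends.arm.constants`. A sketch of what that module presumably exposes; the NHWC/HWCM values are copied from the attributes removed above, while `NCHW_ORDER` and `NNCHW_ORDER` are assumptions (the contiguous dim orders, consistent with how the new code compares them against `Tensor.dim_order()`):

```python
# Assumed contents of executorch/backends/arm/constants.py (sketch, not verified).
NCHW_ORDER = (0, 1, 2, 3)         # assumed: contiguous 4-D order
NNCHW_ORDER = (0, 1, 2, 3, 4)     # assumed: contiguous 5-D order
NHWC_ORDER = (0, 2, 3, 1)         # from the removed NHWC_order
NHWC_INVERSE_ORDER = (0, 3, 1, 2)
NNHWC_ORDER = (0, 1, 3, 4, 2)
NNHWC_INVERSE_ORDER = (0, 1, 4, 2, 3)
HWCM_ORDER = (2, 3, 0, 1)         # TOSA DEPTHWISE_CONV2D weight order
```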
@@ -135,9 +139,9 @@ def insert_input_transpose(node, input_node, graph_module):
             args=(
                 input_node,
                 list(
-                    ToTosaMemoryFormatPass.NNHWC_inverse_order
+                    NNHWC_INVERSE_ORDER
                     if len(get_first_fake_tensor(input_node).size()) == 5
-                    else ToTosaMemoryFormatPass.NHWC_inverse_order
+                    else NHWC_INVERSE_ORDER
                 ),
             ),
             from_node=node,
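The `*_INVERSE_ORDER` tuples used here are the inverse permutations of their `*_ORDER` counterparts, so the transpose inserted before a node undoes the NHWC-style annotation. A quick plain-Python check, using the values from the constants sketched above:

```python
def inverse_order(order):
    # The inverse permutation maps each axis of the permuted tensor
    # back to its position in the original tensor.
    return tuple(order.index(i) for i in range(len(order)))

assert inverse_order((0, 2, 3, 1)) == (0, 3, 1, 2)        # NHWC_ORDER  -> NHWC_INVERSE_ORDER
assert inverse_order((0, 1, 3, 4, 2)) == (0, 1, 4, 2, 3)  # NNHWC_ORDER -> NNHWC_INVERSE_ORDER
```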
@@ -157,18 +161,18 @@ def insert_output_transpose(node, graph_module):
             args=(
                 node,
                 list(
-                    ToTosaMemoryFormatPass.NNHWC_order
+                    NNHWC_ORDER
                     if len(get_first_fake_tensor(node).size()) == 5
-                    else ToTosaMemoryFormatPass.NHWC_order
+                    else NHWC_ORDER
                 ),
             ),
             from_node=node,
         )
 
         permute_node.meta["tosa_dim_order"] = (
-            ToTosaMemoryFormatPass.NNHWC_order
+            NNHWC_ORDER
             if len(get_first_fake_tensor(node).size()) == 5
-            else ToTosaMemoryFormatPass.NHWC_order
+            else NHWC_ORDER
         )
         node.meta["tosa_dim_order"] = tuple(
             range(len(get_first_fake_tensor(node).size()))
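For intuition on what these permute nodes do: applying `NHWC_ORDER` to an NCHW-shaped tensor yields an NHWC shape, and the inverse order restores it. A minimal torch sketch with an arbitrary shape:

```python
import torch

x = torch.empty(1, 64, 32, 32)        # NCHW-shaped input
nhwc = x.permute(0, 2, 3, 1)          # NHWC_ORDER -> shape (1, 32, 32, 64)
back = nhwc.permute(0, 3, 1, 2)       # NHWC_INVERSE_ORDER -> original shape
assert back.shape == x.shape
```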
@@ -218,7 +222,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             # call_function and placeholder allowed due to
             # index.Tensor being able to come in as both
-            if node.op not in ["call_function", "placeholder", "output"]:
+            if node.op != "call_function":
                 continue
 
             # Transpose views
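With the tightened filter, the loop only rewrites `call_function` nodes; placeholders and outputs are handled by the dedicated input/output passes added in the next hunk. For orientation, a toy `torch.fx` trace (hypothetical module, purely illustrative) showing the `node.op` values involved:

```python
import torch
import torch.fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(Toy())
for n in gm.graph.nodes:
    print(n.op, n.name)
# placeholder x
# call_function relu
# call_function add
# output output
```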
@@ -240,21 +244,33 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                     graph_module,
                 )
 
-            # Transpose inputs
-            elif _is_input(node, self.exported_program):
-                input_shape = get_first_fake_tensor(node).size()
-                if len(input_shape) in (4, 5):
-                    ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
+        output_node = graph_module.graph.output_node()
 
-            # Transpose outputs
-            elif node.op == "output":
-                output_shape = get_first_fake_tensor(node).size()
+        # Transpose inputs if they are in (N)NCHW format
+        inputs = [
+            n for n in graph_module.graph.nodes if _is_input(n, self.exported_program)
+        ]
+        for input_node in inputs:
+            input_dim_order = get_first_fake_tensor(input_node).dim_order()
+            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
+                self.insert_output_transpose(input_node, graph_module)
+
+        # Transpose outputs if they are in (N)NCHW format
+        outputs = output_node.args[0]
+        output_dim_orders = output_node.meta.get("original_dim_orders")
+        if output_dim_orders is None:
+            raise RuntimeError(
+                f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}."
+            )
 
-                if len(output_shape) in (4, 5):
-                    for input_node in node.all_input_nodes:
-                        ToTosaMemoryFormatPass.insert_input_transpose(
-                            node, input_node, graph_module
-                        )
+        for output_node_input, output_dim_order in zip(outputs, output_dim_orders):  # type: ignore[arg-type]
+            if output_dim_order in (
+                NCHW_ORDER,
+                NNCHW_ORDER,
+            ):
+                self.insert_input_transpose(
+                    output_node, output_node_input, graph_module
+                )
 
     def remove_dim_order_kwargs(
         self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
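The new input/output handling keys off `Tensor.dim_order()` rather than tensor rank, so only genuinely contiguous (N)NCHW tensors receive transposes. A quick illustration of `dim_order()` (available on recent PyTorch tensors), assuming `NCHW_ORDER == (0, 1, 2, 3)` as sketched earlier:

```python
import torch

x = torch.zeros(1, 3, 8, 8)   # contiguous, i.e. NCHW layout
print(x.dim_order())          # (0, 1, 2, 3) -> would be transposed

y = x.to(memory_format=torch.channels_last)
print(y.dim_order())          # (0, 2, 3, 1), i.e. NHWC -> left alone
```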
@@ -277,17 +293,17 @@ def call(self, graph_module: torch.fx.GraphModule):
             node_data = get_first_fake_tensor(node).data
 
             self.remove_dim_order_kwargs(graph_module, node)
-            # Inputs and outputs are always in (N)NCHW format
+            # Inputs and outputs may vary in dim_order
             if _is_input(node, self.exported_program) or node.op == "output":
-                dim_order = tuple(range(node_data.dim()))
+                dim_order = node_data.dim_order()
             elif node_data.dim() == 4:
-                dim_order = self.NHWC_order
+                dim_order = NHWC_ORDER
                 if self.is_weight_node_for_depthwise_conv2d(node):
                     # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                     # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
-                    dim_order = self.HWCM_order
+                    dim_order = HWCM_ORDER
             elif node_data.dim() == 5:
-                dim_order = self.NNHWC_order
+                dim_order = NNHWC_ORDER
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
 
@@ -300,32 +316,3 @@ def call(self, graph_module: torch.fx.GraphModule):
         graph_module = super().call(graph_module).graph_module
 
         return PassResult(graph_module, True)
-
-    def requires(self, graph_module) -> None:
-        """
-        This is the only pass which handles dim_orders, so verify that the output dim_orders has not changed since the beginning of the lowering pipeline.
-        """
-
-        dim_orders = get_output_dim_orders(graph_module)
-        original_dim_orders = graph_module.graph.output_node().meta.get(
-            "original_dim_orders"
-        )
-        output_node = graph_module.graph.output_node()
-
-        if original_dim_orders is None:
-            raise RuntimeError(
-                f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run."
-            )
-
-        if len(dim_orders) != len(original_dim_orders):
-            raise RuntimeError(
-                f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run."
-            )
-
-        for node, dim_order, original_dim_order in zip(
-            output_node.args[0], dim_orders, original_dim_orders
-        ):
-            if dim_order != original_dim_order:
-                raise RuntimeError(
-                    f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run."
-                )