@@ -9,13 +9,23 @@
 import logging
 
 import torch
-from executorch.backends.arm._passes import AnnotateOutputDimOrderPass
+from executorch.backends.arm._passes.annotate_decomposed_matmul import (
+    AnnotateDecomposedMatmulPass,
+)
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
-    get_output_dim_orders,
     is_param_node,
 )
+from executorch.backends.arm.constants import (
+    HWCM_ORDER,
+    NCHW_ORDER,
+    NHWC_INVERSE_ORDER,
+    NHWC_ORDER,
+    NNCHW_ORDER,
+    NNHWC_INVERSE_ORDER,
+    NNHWC_ORDER,
+)
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
@@ -38,12 +48,6 @@ class ToTosaMemoryFormatPass(ExportPass):
     The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
     """
 
-    NHWC_order = (0, 2, 3, 1)
-    NHWC_inverse_order = (0, 3, 1, 2)
-    HWCM_order = (2, 3, 0, 1)
-    NNHWC_order = (0, 1, 3, 4, 2)
-    NNHWC_inverse_order = (0, 1, 4, 2, 3)
-
     def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()
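
The class attributes removed here are plain axis permutations, now shared via executorch.backends.arm.constants. A minimal sketch of how such a dim-order tuple acts on a shape (values copied from the removed attributes; the permute_shape helper is hypothetical, for illustration only):

    NHWC_ORDER = (0, 2, 3, 1)
    NHWC_INVERSE_ORDER = (0, 3, 1, 2)

    def permute_shape(shape, order):
        # Axis i of the result is axis order[i] of the input.
        return tuple(shape[d] for d in order)

    nchw = (1, 3, 224, 224)                 # (N, C, H, W)
    nhwc = permute_shape(nchw, NHWC_ORDER)  # (1, 224, 224, 3)
    assert permute_shape(nhwc, NHWC_INVERSE_ORDER) == nchw
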
@@ -135,9 +139,9 @@ def insert_input_transpose(node, input_node, graph_module):
                 args=(
                     input_node,
                     list(
-                        ToTosaMemoryFormatPass.NNHWC_inverse_order
+                        NNHWC_INVERSE_ORDER
                         if len(get_first_fake_tensor(input_node).size()) == 5
-                        else ToTosaMemoryFormatPass.NHWC_inverse_order
+                        else NHWC_INVERSE_ORDER
                     ),
                 ),
                 from_node=node,
@@ -157,18 +161,18 @@ def insert_output_transpose(node, graph_module):
                 args=(
                     node,
                     list(
-                        ToTosaMemoryFormatPass.NNHWC_order
+                        NNHWC_ORDER
                         if len(get_first_fake_tensor(node).size()) == 5
-                        else ToTosaMemoryFormatPass.NHWC_order
+                        else NHWC_ORDER
                     ),
                 ),
                 from_node=node,
             )
 
             permute_node.meta["tosa_dim_order"] = (
-                ToTosaMemoryFormatPass.NNHWC_order
+                NNHWC_ORDER
                 if len(get_first_fake_tensor(node).size()) == 5
-                else ToTosaMemoryFormatPass.NHWC_order
+                else NHWC_ORDER
             )
             node.meta["tosa_dim_order"] = tuple(
                 range(len(get_first_fake_tensor(node).size()))
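
The 5-D orders follow the same pattern. A quick check (values from the constants removed earlier) that the order and its inverse compose to the identity, which is what lets a transpose inserted by insert_input_transpose undo one inserted by insert_output_transpose:

    NNHWC_ORDER = (0, 1, 3, 4, 2)
    NNHWC_INVERSE_ORDER = (0, 1, 4, 2, 3)

    # Composing the permutation with its inverse yields the identity,
    # so a transpose in followed by a transpose out is a no-op.
    composed = tuple(NNHWC_ORDER[d] for d in NNHWC_INVERSE_ORDER)
    assert composed == (0, 1, 2, 3, 4)
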
@@ -218,7 +222,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             # call_function and placeholder allowed due to
             # index.Tensor being able to come in as both
-            if node.op not in ["call_function", "placeholder", "output"]:
+            if node.op != "call_function":
                 continue
 
             # Transpose views
@@ -240,21 +244,33 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                     graph_module,
                 )
 
-            # Transpose inputs
-            elif _is_input(node, self.exported_program):
-                input_shape = get_first_fake_tensor(node).size()
-                if len(input_shape) in (4, 5):
-                    ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
+        output_node = graph_module.graph.output_node()
 
-            # Transpose outputs
-            elif node.op == "output":
-                output_shape = get_first_fake_tensor(node).size()
+        # Transpose inputs if they are in (N)NCHW format
+        inputs = [
+            n for n in graph_module.graph.nodes if _is_input(n, self.exported_program)
+        ]
+        for input_node in inputs:
+            input_dim_order = get_first_fake_tensor(input_node).dim_order()
+            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
+                self.insert_output_transpose(input_node, graph_module)
+
+        # Transpose outputs if they are in (N)NCHW format
+        outputs = output_node.args[0]
+        output_dim_orders = output_node.meta.get("original_dim_orders")
+        if output_dim_orders is None:
+            raise RuntimeError(
+                f"{AnnotateDecomposedMatmulPass.__name__} is required to run at the beginning of the pass pipeline when using {ToTosaMemoryFormatPass.__name__}."
+            )
 
-            if len(output_shape) in (4, 5):
-                for input_node in node.all_input_nodes:
-                    ToTosaMemoryFormatPass.insert_input_transpose(
-                        node, input_node, graph_module
-                    )
+        for output_node_input, output_dim_order in zip(outputs, output_dim_orders):  # type: ignore[arg-type]
+            if output_dim_order in (
+                NCHW_ORDER,
+                NNCHW_ORDER,
+            ):
+                self.insert_input_transpose(
+                    output_node, output_node_input, graph_module
+                )
 
     def remove_dim_order_kwargs(
         self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
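
The new input/output handling keys off the tensor's actual dim order rather than its rank alone. A small sketch of what torch.Tensor.dim_order() returns (assuming a recent PyTorch build that exposes this method):

    import torch

    x = torch.randn(1, 3, 8, 8)                  # contiguous, i.e. NCHW
    print(x.dim_order())                         # (0, 1, 2, 3), presumably NCHW_ORDER
    y = x.to(memory_format=torch.channels_last)  # NHWC in memory
    print(y.dim_order())                         # (0, 2, 3, 1), the NHWC order
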
@@ -277,17 +293,17 @@ def call(self, graph_module: torch.fx.GraphModule):
             node_data = get_first_fake_tensor(node).data
 
             self.remove_dim_order_kwargs(graph_module, node)
-            # Inputs and outputs are always in (N)NCHW format
+            # Inputs and outputs may vary in dim_order
             if _is_input(node, self.exported_program) or node.op == "output":
-                dim_order = tuple(range(node_data.dim()))
+                dim_order = node_data.dim_order()
             elif node_data.dim() == 4:
-                dim_order = self.NHWC_order
+                dim_order = NHWC_ORDER
                 if self.is_weight_node_for_depthwise_conv2d(node):
                     # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                     # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
-                    dim_order = self.HWCM_order
+                    dim_order = HWCM_ORDER
             elif node_data.dim() == 5:
-                dim_order = self.NNHWC_order
+                dim_order = NNHWC_ORDER
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
 
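
For the depthwise branch, HWCM_ORDER = (2, 3, 0, 1) maps a 4-D weight onto the (H, W, C, M) layout that the TOSA DEPTHWISE_CONV2D spec cited in the comment requires. A sketch, assuming for illustration a weight whose dims arrive ordered (C, M, kH, kW); that starting layout is an assumption, not taken from this diff:

    HWCM_ORDER = (2, 3, 0, 1)
    cmhw = (8, 1, 3, 3)  # hypothetical (C, M, kH, kW): 8 channels, multiplier 1
    hwcm = tuple(cmhw[d] for d in HWCM_ORDER)
    assert hwcm == (3, 3, 8, 1)  # (H, W, C, M) as TOSA expects
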
@@ -300,32 +316,3 @@ def call(self, graph_module: torch.fx.GraphModule):
         graph_module = super().call(graph_module).graph_module
 
         return PassResult(graph_module, True)
-
-    def requires(self, graph_module) -> None:
-        """
-        This is the only pass which handles dim_orders, so verify that the output dim_orders has not changed since the beginning of the lowering pipeline.
-        """
-
-        dim_orders = get_output_dim_orders(graph_module)
-        original_dim_orders = graph_module.graph.output_node().meta.get(
-            "original_dim_orders"
-        )
-        output_node = graph_module.graph.output_node()
-
-        if original_dim_orders is None:
-            raise RuntimeError(
-                f"{AnnotateOutputDimOrderPass.__name__} must be run in the beginning of the pass pipeline to verify that the dim order has not changed unexpectedly during its run."
-            )
-
-        if len(dim_orders) != len(original_dim_orders):
-            raise RuntimeError(
-                f"The number of outputs has changed since {AnnotateOutputDimOrderPass.__name__} was run."
-            )
-
-        for node, dim_order, original_dim_order in zip(
-            output_node.args[0], dim_orders, original_dim_orders
-        ):
-            if dim_order != original_dim_order:
-                raise RuntimeError(
-                    f"The dim order of output {node.name} has changed from {original_dim_order} to {dim_order} since {AnnotateOutputDimOrderPass.__name__} was run."
-                )