     NNCHW_ORDER,
     NNHWC_INVERSE_ORDER,
     NNHWC_ORDER,
+    NNNCHW_ORDER,
+    NNNHWC_INVERSE_ORDER,
+    NNNHWC_ORDER,
 )
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -51,12 +54,6 @@ class ToTosaMemoryFormatPass(ExportPass):
 
     _passes_required_after: Set[Type[ExportPass]] = set()
 
-    NHWC_order = (0, 2, 3, 1)
-    NHWC_inverse_order = (0, 3, 1, 2)
-    HWCM_order = (2, 3, 0, 1)
-    NNHWC_order = (0, 1, 3, 4, 2)
-    NNHWC_inverse_order = (0, 1, 4, 2, 3)
-
     def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()
@@ -93,7 +90,11 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
     @staticmethod
     def memory_format_differs(shape):
         """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
-        if len(shape) >= 5:
+        if len(shape) >= 6:
+            C = shape[3]
+            H = shape[4]
+            W = shape[5]
+        elif len(shape) == 5:
             C = shape[2]
             H = shape[3]
             W = shape[4]
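Note on the rank dispatch above: for every rank handled here (4, 5 and 6), channel, height and width are the trailing three entries of the shape, so the explicit indices are equivalent to negative indexing. A minimal sketch of that equivalence; the _chw helper is made up for illustration and is not part of the pass:

def _chw(shape):
    # For (N)NCHW-style shapes of rank 4, 5 or 6, the trailing three dims are (C, H, W).
    return shape[-3], shape[-2], shape[-1]

assert _chw((1, 3, 8, 8)) == (3, 8, 8)        # NCHW
assert _chw((2, 1, 3, 8, 8)) == (3, 8, 8)     # NNCHW
assert _chw((2, 2, 1, 3, 8, 8)) == (3, 8, 8)  # NNNCHW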
@@ -112,25 +113,26 @@ def memory_format_differs(shape):
 
     @staticmethod
     def is_channel_reshape(input_shape, output_shape):
-        """Returns true if the reshape changes the channel dimension"""
-        if not (
-            (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
-            or (len(input_shape) == 4 and len(output_shape) == 5)
-            or (len(input_shape) == 5 and len(output_shape) == 4)
-        ):
+        """Returns true if the reshape changes the channel dimension or the product of the batch dimensions"""
+
+        valid_ranks = {4, 5, 6}
+
+        if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
             return False
 
         C_old = input_shape[-3]
         C_new = output_shape[-3]
 
-        N_new = (
-            output_shape[0]
-            if len(output_shape) == 4
-            else output_shape[0] * output_shape[1]
-        )
-        N_old = (
-            input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
-        )
+        def get_batch_prod_dim(shape):
+            product = 1
+
+            for dim in shape[:-3]:
+                product = product * dim
+
+            return product
+
+        N_old = get_batch_prod_dim(input_shape)
+        N_new = get_batch_prod_dim(output_shape)
 
         return (N_old != N_new) or (C_old != C_new)
 
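A quick worked example of the new batch-product check (shapes are illustrative; the module path of ToTosaMemoryFormatPass is not shown in this diff): folding the two batch dims of a rank-5 shape into one keeps both the channel dim and the batch product, so it is not a channel reshape, while changing the channel count is.

# Batch product 2 * 3 == 6 and C == 8 on both sides -> not a channel reshape.
assert not ToTosaMemoryFormatPass.is_channel_reshape((2, 3, 8, 16, 16), (6, 8, 16, 16))
# Channel dim changes from 8 to 4 -> channel reshape.
assert ToTosaMemoryFormatPass.is_channel_reshape((2, 3, 8, 16, 16), (6, 4, 32, 16))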
@@ -141,17 +143,27 @@ def insert_input_transpose(node, input_node, graph_module):
             node.replace_input_with(input_node, pre_permute_node)
             return
 
+        if len(get_first_fake_tensor(input_node).size()) == 6:
+            mem_format = NNNHWC_INVERSE_ORDER
+        elif len(get_first_fake_tensor(input_node).size()) == 5:
+            mem_format = NNHWC_INVERSE_ORDER
+        else:
+            mem_format = NHWC_INVERSE_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(
+            get_first_fake_tensor(input_node).size()
+        )  # or (node) in output path
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose"
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     input_node,
-                    list(
-                        NNHWC_INVERSE_ORDER
-                        if len(get_first_fake_tensor(input_node).size()) == 5
-                        else NHWC_INVERSE_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
@@ -163,26 +175,38 @@ def insert_input_transpose(node, input_node, graph_module):
 
     @staticmethod
     def insert_output_transpose(node, graph_module):
+
+        if len(get_first_fake_tensor(node).size()) == 6:
+            mem_format = NNNHWC_ORDER
+        elif len(get_first_fake_tensor(node).size()) == 5:
+            mem_format = NNHWC_ORDER
+        else:
+            mem_format = NHWC_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(get_first_fake_tensor(node).size())
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_output_transpose"
+
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     node,
-                    list(
-                        NNHWC_ORDER
-                        if len(get_first_fake_tensor(node).size()) == 5
-                        else NHWC_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
 
-            permute_node.meta["tosa_dim_order"] = (
-                NNHWC_ORDER
-                if len(get_first_fake_tensor(node).size()) == 5
-                else NHWC_ORDER
-            )
+            rank = len(get_first_fake_tensor(node).size())
+            if rank == 6:
+                permute_node.meta["tosa_dim_order"] = NNNHWC_ORDER
+            elif rank == 5:
+                permute_node.meta["tosa_dim_order"] = NNHWC_ORDER
+            else:
+                permute_node.meta["tosa_dim_order"] = NHWC_ORDER
+
             node.meta["tosa_dim_order"] = tuple(
                 range(len(get_first_fake_tensor(node).size()))
             )
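The rank-6 order constants imported at the top are not defined in this diff. Assuming they follow the same pattern as the existing rank-4/5 constants (batch dims untouched, channels moved last), they would be NNNCHW_ORDER = (0, 1, 2, 3, 4, 5), NNNHWC_ORDER = (0, 1, 2, 4, 5, 3) and NNNHWC_INVERSE_ORDER = (0, 1, 2, 5, 3, 4). Under that assumption the new permutation guard holds and the two orders compose to the identity:

# Assumed values, mirroring NNHWC_ORDER / NNHWC_INVERSE_ORDER; not taken from this diff.
NNNHWC_ORDER = (0, 1, 2, 4, 5, 3)
NNNHWC_INVERSE_ORDER = (0, 1, 2, 5, 3, 4)

rank = 6
assert sorted(NNNHWC_ORDER) == list(range(rank))  # the guard added above
# Applying the order and then its inverse restores the original dim order.
assert [NNNHWC_ORDER[i] for i in NNNHWC_INVERSE_ORDER] == list(range(rank))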
@@ -261,7 +285,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         ]
         for input_node in inputs:
             input_dim_order = get_first_fake_tensor(input_node).dim_order()
-            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
+            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER):
                 self.insert_output_transpose(input_node, graph_module)
 
         # Transpose outputs if they are in (N)NCHW format
@@ -276,6 +300,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
             if output_dim_order in (
                 NCHW_ORDER,
                 NNCHW_ORDER,
+                NNNCHW_ORDER,
             ):
                 self.insert_input_transpose(
                     output_node, output_node_input, graph_module
@@ -313,6 +338,8 @@ def call(self, graph_module: torch.fx.GraphModule):
                     dim_order = HWCM_ORDER
                 elif node_data.dim() == 5:
                     dim_order = NNHWC_ORDER
+                elif node_data.dim() == 6:
+                    dim_order = NNNHWC_ORDER
                 else:
                     dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
 
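With this change a rank-6 activation is tagged with NNNHWC_ORDER as its tosa_dim_order, i.e. channels move to the innermost position while the three leading batch dims stay put. A small illustration under the assumed constant from the note above:

shape = (2, 2, 1, 8, 16, 16)  # logical (N, N, N, C, H, W)
order = (0, 1, 2, 4, 5, 3)    # assumed NNNHWC_ORDER
assert tuple(shape[d] for d in order) == (2, 2, 1, 16, 16, 8)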