     NNCHW_ORDER,
     NNHWC_INVERSE_ORDER,
     NNHWC_ORDER,
+    NNNCHW_ORDER,
+    NNNHWC_INVERSE_ORDER,
+    NNNHWC_ORDER,
 )
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
@@ -51,12 +54,6 @@ class ToTosaMemoryFormatPass(ExportPass):
 
     _passes_required_after: Set[Type[ExportPass]] = set()
 
-    NHWC_order = (0, 2, 3, 1)
-    NHWC_inverse_order = (0, 3, 1, 2)
-    HWCM_order = (2, 3, 0, 1)
-    NNHWC_order = (0, 1, 3, 4, 2)
-    NNHWC_inverse_order = (0, 1, 4, 2, 3)
-
     def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()
@@ -93,7 +90,11 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
     @staticmethod
     def memory_format_differs(shape):
         """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
-        if len(shape) >= 5:
+        if len(shape) >= 6:
+            C = shape[3]
+            H = shape[4]
+            W = shape[5]
+        elif len(shape) == 5:
             C = shape[2]
             H = shape[3]
             W = shape[4]
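Worth noting for the branch above: in every rank this pass supports, C, H, and W occupy the last three axes, so the fixed indices 3, 4, 5 in the rank-6 branch are just `shape[-3:]`. A minimal standalone sketch (the helper name `chw_from_shape` is ours, not part of the pass):

```python
# Minimal sketch, not the pass's code: C, H, W always occupy the last
# three axes for the supported ranks, so the fixed indices in the
# rank-6 branch (3, 4, 5) are equivalent to negative indexing.
def chw_from_shape(shape):
    if len(shape) not in (4, 5, 6):
        raise ValueError(f"unsupported rank {len(shape)}")
    return shape[-3], shape[-2], shape[-1]

assert chw_from_shape((2, 2, 2, 8, 16, 16)) == (8, 16, 16)  # rank 6: NNNCHW
assert chw_from_shape((2, 8, 16, 16)) == (8, 16, 16)        # rank 4: NCHW
```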
@@ -112,25 +113,26 @@ def memory_format_differs(shape):
 
     @staticmethod
     def is_channel_reshape(input_shape, output_shape):
-        """Returns true if the reshape changes the channel dimension"""
-        if not (
-            (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
-            or (len(input_shape) == 4 and len(output_shape) == 5)
-            or (len(input_shape) == 5 and len(output_shape) == 4)
-        ):
+        """Returns true if the reshape changes the channel dimension or the batch product dimension(s)"""
+
+        valid_ranks = {4, 5, 6}
+
+        if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
             return False
 
         C_old = input_shape[-3]
         C_new = output_shape[-3]
 
-        N_new = (
-            output_shape[0]
-            if len(output_shape) == 4
-            else output_shape[0] * output_shape[1]
-        )
-        N_old = (
-            input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
-        )
+        def get_batch_prod_dim(shape):
+            product = 1
+            for dim in shape[:-3]:
+                product = product * dim
+            return product
+
+        N_old = get_batch_prod_dim(input_shape)
+        N_new = get_batch_prod_dim(output_shape)
 
         return (N_old != N_new) or (C_old != C_new)
 
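To make the new semantics concrete: a reshape that only folds or splits the leading batch dims preserves both the batch product and C, so it is not treated as a channel reshape. A hedged, illustrative re-implementation of the predicate using `math.prod`:

```python
import math

# Illustrative re-implementation of the predicate above, not the pass's code.
def is_channel_reshape(input_shape, output_shape):
    valid_ranks = {4, 5, 6}
    if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
        return False
    return (
        math.prod(input_shape[:-3]) != math.prod(output_shape[:-3])
        or input_shape[-3] != output_shape[-3]
    )

# Folding the (2, 3) batch dims into 6 keeps the batch product and C,
# so layout conversion commutes with the reshape:
assert not is_channel_reshape((2, 3, 8, 4, 4), (6, 8, 4, 4))
# Moving elements across the channel axis does not:
assert is_channel_reshape((2, 3, 8, 4, 4), (6, 4, 8, 4))
```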
@@ -141,17 +143,27 @@ def insert_input_transpose(node, input_node, graph_module):
             node.replace_input_with(input_node, pre_permute_node)
             return
 
+        if len(get_first_fake_tensor(input_node).size()) == 6:
+            mem_format = NNNHWC_INVERSE_ORDER
+        elif len(get_first_fake_tensor(input_node).size()) == 5:
+            mem_format = NNHWC_INVERSE_ORDER
+        else:
+            mem_format = NHWC_INVERSE_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        rank = len(get_first_fake_tensor(input_node).size())
+        assert sorted(mem_format) == list(
+            range(rank)
+        ), f"bad perm {mem_format} for rank {rank} in insert_input_transpose"
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     input_node,
-                    list(
-                        NNHWC_INVERSE_ORDER
-                        if len(get_first_fake_tensor(input_node).size()) == 5
-                        else NHWC_INVERSE_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
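The guard is cheap insurance that the chosen order is a valid permutation for the tensor's rank. A self-contained check of the same invariant, with the order constants inlined for illustration (the rank-6 values are assumed by analogy with the rank-4/5 ones; the canonical definitions live in the imported module):

```python
# Self-contained check of the guard's invariant; constants inlined for
# illustration. The rank-6 values are assumed by analogy with the
# rank-4/5 orders, not copied from the imported module.
NHWC_ORDER, NHWC_INVERSE_ORDER = (0, 2, 3, 1), (0, 3, 1, 2)
NNHWC_ORDER, NNHWC_INVERSE_ORDER = (0, 1, 3, 4, 2), (0, 1, 4, 2, 3)
NNNHWC_ORDER, NNNHWC_INVERSE_ORDER = (0, 1, 2, 4, 5, 3), (0, 1, 2, 5, 3, 4)

for fwd, inv in [
    (NHWC_ORDER, NHWC_INVERSE_ORDER),
    (NNHWC_ORDER, NNHWC_INVERSE_ORDER),
    (NNNHWC_ORDER, NNNHWC_INVERSE_ORDER),
]:
    rank = len(fwd)
    assert sorted(fwd) == list(range(rank))  # the same guard as in the pass
    # applying fwd then inv must restore the original axis order
    assert [fwd[i] for i in inv] == list(range(rank))
```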
@@ -163,26 +175,38 @@ def insert_input_transpose(node, input_node, graph_module):
 
     @staticmethod
     def insert_output_transpose(node, graph_module):
+
+        if len(get_first_fake_tensor(node).size()) == 6:
+            mem_format = NNNHWC_ORDER
+        elif len(get_first_fake_tensor(node).size()) == 5:
+            mem_format = NNHWC_ORDER
+        else:
+            mem_format = NHWC_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        rank = len(get_first_fake_tensor(node).size())
+        assert sorted(mem_format) == list(
+            range(rank)
+        ), f"bad perm {mem_format} for rank {rank} in insert_output_transpose"
+
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     node,
-                    list(
-                        NNHWC_ORDER
-                        if len(get_first_fake_tensor(node).size()) == 5
-                        else NHWC_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
 
-        permute_node.meta["tosa_dim_order"] = (
-            NNHWC_ORDER
-            if len(get_first_fake_tensor(node).size()) == 5
-            else NHWC_ORDER
-        )
+        permute_node.meta["tosa_dim_order"] = mem_format
         node.meta["tosa_dim_order"] = tuple(
             range(len(get_first_fake_tensor(node).size()))
         )
@@ -261,7 +285,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         ]
         for input_node in inputs:
             input_dim_order = get_first_fake_tensor(input_node).dim_order()
-            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
+            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER):
                 self.insert_output_transpose(input_node, graph_module)
 
         # Transpose outputs if they are in (N)NCHW format
@@ -276,6 +300,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                 if output_dim_order in (
                     NCHW_ORDER,
                     NNCHW_ORDER,
+                    NNNCHW_ORDER,
                 ):
                     self.insert_input_transpose(
                         output_node, output_node_input, graph_module
@@ -313,6 +338,8 @@ def call(self, graph_module: torch.fx.GraphModule):
                     dim_order = HWCM_ORDER
             elif node_data.dim() == 5:
                 dim_order = NNHWC_ORDER
+            elif node_data.dim() == 6:
+                dim_order = NNNHWC_ORDER
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
 
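For reference, the rank dispatch for constant data above can be read as a table lookup with an identity fallback; a sketch under the same assumption about the rank-6 order, deliberately omitting the depthwise HWCM case:

```python
# Sketch only, not the pass's code: rank -> dim_order as a table with an
# identity fallback. The HWCM depthwise-weight branch shown above is
# omitted, and the rank-6 order is assumed by analogy, not verified.
NHWC, NNHWC, NNNHWC = (0, 2, 3, 1), (0, 1, 3, 4, 2), (0, 1, 2, 4, 5, 3)

def constant_dim_order(rank: int):
    table = {4: NHWC, 5: NNHWC, 6: NNNHWC}
    return table.get(rank, tuple(range(rank)))

assert constant_dim_order(6) == (0, 1, 2, 4, 5, 3)
assert constant_dim_order(3) == (0, 1, 2)  # identity fallback
```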