     NNCHW_ORDER,
     NNHWC_INVERSE_ORDER,
     NNHWC_ORDER,
+    NNNCHW_ORDER,
+    NNNHWC_INVERSE_ORDER,
+    NNNHWC_ORDER,
 )
 from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
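For reference, the three rank-6 orders imported here are presumably defined next to the existing rank-4/5 constants in arm_pass_utils. A minimal sketch of the assumed values, extrapolated from the NHWC/NNHWC pattern (verify against the actual module before relying on them):

    # Assumed values, following the existing pattern:
    #   NHWC_ORDER = (0, 2, 3, 1), NNHWC_ORDER = (0, 1, 3, 4, 2)
    NNNCHW_ORDER = (0, 1, 2, 3, 4, 5)          # contiguous rank-6 layout
    NNNHWC_ORDER = (0, 1, 2, 4, 5, 3)          # channels-last with three leading batch dims
    NNNHWC_INVERSE_ORDER = (0, 1, 2, 5, 3, 4)  # permutation that undoes NNNHWC_ORDER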
@@ -48,6 +51,8 @@ class ToTosaMemoryFormatPass(ExportPass):
     The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
     """
 
+    _passes_required_after: Set[Type[ExportPass]] = set()
+
     def __init__(self, exported_program: ExportedProgram) -> None:
         self.exported_program = exported_program
         super().__init__()
@@ -84,7 +89,11 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
     @staticmethod
     def memory_format_differs(shape):
         """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
-        if len(shape) >= 5:
+        if len(shape) >= 6:
+            C = shape[3]
+            H = shape[4]
+            W = shape[5]
+        elif len(shape) == 5:
             C = shape[2]
             H = shape[3]
             W = shape[4]
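The rank-6 branch shifts the C/H/W reads right by one to skip the third batch dimension. Since every branch ends up selecting the trailing three dimensions, a negative-index form would be equivalent; a small illustration (hypothetical helper, shown only to make the dispatch concrete):

    def chw(shape):
        return shape[-3], shape[-2], shape[-1]

    assert chw((2, 3, 8, 8)) == (3, 8, 8)        # rank 4: C = shape[1]
    assert chw((2, 2, 3, 8, 8)) == (3, 8, 8)     # rank 5: C = shape[2]
    assert chw((2, 2, 2, 3, 8, 8)) == (3, 8, 8)  # rank 6: C = shape[3]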
@@ -103,25 +112,26 @@ def memory_format_differs(shape):
 
     @staticmethod
     def is_channel_reshape(input_shape, output_shape):
-        """Returns true if the reshape changes the channel dimension"""
-        if not (
-            (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
-            or (len(input_shape) == 4 and len(output_shape) == 5)
-            or (len(input_shape) == 5 and len(output_shape) == 4)
-        ):
+        """Returns true if the reshape changes the channel dimension or the batch product dimension(s)"""
+
+        valid_ranks = {4, 5, 6}
+
+        if not (len(input_shape) in valid_ranks and len(output_shape) in valid_ranks):
             return False
 
         C_old = input_shape[-3]
         C_new = output_shape[-3]
 
-        N_new = (
-            output_shape[0]
-            if len(output_shape) == 4
-            else output_shape[0] * output_shape[1]
-        )
-        N_old = (
-            input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
-        )
+        def get_batch_prod_dim(shape):
+            product = 1
+
+            for dim in shape[:-3]:
+                product = product * dim
+
+            return product
+
+        N_old = get_batch_prod_dim(input_shape)
+        N_new = get_batch_prod_dim(output_shape)
 
         return (N_old != N_new) or (C_old != C_new)
 
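get_batch_prod_dim collapses every leading dimension ahead of C/H/W into a single batch product, so a reshape that merely regroups batch dims is not flagged as a channel reshape. A standalone sketch (math.prod is equivalent to the loop above):

    import math

    def get_batch_prod_dim(shape):
        return math.prod(shape[:-3])  # empty slice -> 1, matching the loop

    # (2, 2, 4, 8, 8) -> (4, 4, 8, 8): batch product 4 == 4 and C stays 4,
    # so this regrouping is not a channel reshape.
    assert get_batch_prod_dim((2, 2, 4, 8, 8)) == get_batch_prod_dim((4, 4, 8, 8))
    # (2, 2, 4, 8, 8) -> (8, 2, 4, 8, 8): batch product 4 != 16, so it is.
    assert get_batch_prod_dim((8, 2, 4, 8, 8)) == 16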
@@ -132,17 +142,27 @@ def insert_input_transpose(node, input_node, graph_module):
             node.replace_input_with(input_node, pre_permute_node)
             return
 
+        if len(get_first_fake_tensor(input_node).size()) == 6:
+            mem_format = NNNHWC_INVERSE_ORDER
+        elif len(get_first_fake_tensor(input_node).size()) == 5:
+            mem_format = NNHWC_INVERSE_ORDER
+        else:
+            mem_format = NHWC_INVERSE_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(
+            get_first_fake_tensor(input_node).size()
+        )
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_input_transpose"
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     input_node,
-                    list(
-                        NNHWC_INVERSE_ORDER
-                        if len(get_first_fake_tensor(input_node).size()) == 5
-                        else NHWC_INVERSE_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
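The new assert is a cheap bijection check: a tuple is a valid permutation for rank r exactly when sorting it yields 0..r-1, which catches both repeated axes and a constant picked for the wrong rank. For example (constant value assumed, as in the first sketch):

    NNHWC_INVERSE_ORDER = (0, 1, 4, 2, 3)                 # assumed rank-5 value
    assert sorted(NNHWC_INVERSE_ORDER) == list(range(5))  # valid rank-5 perm
    assert sorted(NNHWC_INVERSE_ORDER) != list(range(6))  # rank mismatch caught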
@@ -154,26 +174,38 @@ def insert_input_transpose(node, input_node, graph_module):
 
     @staticmethod
     def insert_output_transpose(node, graph_module):
+
+        if len(get_first_fake_tensor(node).size()) == 6:
+            mem_format = NNNHWC_ORDER
+        elif len(get_first_fake_tensor(node).size()) == 5:
+            mem_format = NNHWC_ORDER
+        else:
+            mem_format = NHWC_ORDER
+        # Guard: mem_format must be a true permutation for the current rank
+        _rank_ = len(get_first_fake_tensor(node).size())
+        assert sorted(mem_format) == list(
+            range(_rank_)
+        ), f"bad perm {mem_format} for rank {_rank_} in insert_output_transpose"
+
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     node,
-                    list(
-                        NNHWC_ORDER
-                        if len(get_first_fake_tensor(node).size()) == 5
-                        else NHWC_ORDER
-                    ),
+                    list(mem_format),
                 ),
                 from_node=node,
             )
 
-        permute_node.meta["tosa_dim_order"] = (
-            NNHWC_ORDER
-            if len(get_first_fake_tensor(node).size()) == 5
-            else NHWC_ORDER
-        )
+        rank = len(get_first_fake_tensor(node).size())
+        if rank == 6:
+            permute_node.meta["tosa_dim_order"] = NNNHWC_ORDER
+        elif rank == 5:
+            permute_node.meta["tosa_dim_order"] = NNHWC_ORDER
+        else:
+            permute_node.meta["tosa_dim_order"] = NHWC_ORDER
+
         node.meta["tosa_dim_order"] = tuple(
             range(len(get_first_fake_tensor(node).size()))
         )
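insert_output_transpose tags the new TRANSPOSE node with the channels-last order and re-tags the original node as contiguous. A quick sanity check, under the assumed constant values from the first sketch, that each ORDER composed with its INVERSE is the identity, which is what lets a producer-side output transpose cancel against a consumer-side input transpose:

    NNNHWC_ORDER = (0, 1, 2, 4, 5, 3)          # assumed values
    NNNHWC_INVERSE_ORDER = (0, 1, 2, 5, 3, 4)

    composed = [NNNHWC_ORDER[i] for i in NNNHWC_INVERSE_ORDER]
    assert composed == list(range(6))  # permute + inverse permute = identity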
@@ -252,7 +284,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         ]
         for input_node in inputs:
             input_dim_order = get_first_fake_tensor(input_node).dim_order()
-            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER):
+            if input_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER):
                 self.insert_output_transpose(input_node, graph_module)
 
         # Transpose outputs if they are in (N)NCHW format
@@ -267,6 +299,7 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
             if output_dim_order in (
                 NCHW_ORDER,
                 NNCHW_ORDER,
+                NNNCHW_ORDER,
             ):
                 self.insert_input_transpose(
                     output_node, output_node_input, graph_module
@@ -304,6 +337,8 @@ def call(self, graph_module: torch.fx.GraphModule):
                 dim_order = HWCM_ORDER
             elif node_data.dim() == 5:
                 dim_order = NNHWC_ORDER
+            elif node_data.dim() == 6:
+                dim_order = NNNHWC_ORDER
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
 
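Condensing the rank dispatch at the end of call (branches above this hunk, such as the rank-4 NHWC case, are elided; the helper name and constant values below are assumptions for illustration only):

    NNHWC_ORDER = (0, 1, 3, 4, 2)      # assumed values, as in the first sketch
    NNNHWC_ORDER = (0, 1, 2, 4, 5, 3)

    def pick_dim_order(rank):
        # mirrors only the elif/else chain visible in this hunk
        if rank == 5:
            return NNHWC_ORDER
        if rank == 6:
            return NNNHWC_ORDER
        return tuple(range(rank))  # contiguous fallback

    assert pick_dim_order(6) == (0, 1, 2, 4, 5, 3)
    assert pick_dim_order(3) == (0, 1, 2)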