# Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

@@ -36,7 +35,7 @@
def _transpose_impl(*args, **kwargs):
    # Validate length of dim_order array
    dim = args[1]
-    assert len(dim) <= 4
+    assert len(dim) in (4, 5)
    # Pass-through in edge-IR
    return args[0]

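For reference, a minimal sketch of the pass-through behaviour, calling the Python body above directly (ignoring how the op is registered and dispatched); the tensors are made-up examples:

```python
import torch

# _transpose_impl only validates the dim_order length and forwards its input unchanged.
x4 = torch.randn(1, 8, 16, 16)
x5 = torch.randn(2, 1, 8, 16, 16)
assert _transpose_impl(x4, [0, 2, 3, 1]) is x4     # rank-4 dim order accepted
assert _transpose_impl(x5, [0, 1, 3, 4, 2]) is x5  # rank-5 dim order accepted
```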
@@ -45,13 +44,15 @@ class AnnotateChannelsLastDimOrder(ExportPass):
    """
    Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
-    when a transition between 3D and 4D tensors happen.
+    when a transition between 3D and 4D/5D tensors happens.
    The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
    """

    NHWC_order = (0, 2, 3, 1)
    NHWC_inverse_order = (0, 3, 1, 2)
    HWCM_order = (2, 3, 0, 1)
+    NNHWC_order = (0, 1, 3, 4, 2)
+    NNHWC_inverse_order = (0, 1, 4, 2, 3)

    def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
        """
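To make the dim-order constants concrete, a small sketch of how they rearrange a shape; `permute_shape` is a hypothetical helper, not part of the pass:

```python
def permute_shape(shape, dim_order):
    # Reorder the dimensions of `shape` according to `dim_order`.
    return tuple(shape[d] for d in dim_order)

# 4D: NCHW -> NHWC via NHWC_order
assert permute_shape((1, 8, 16, 16), (0, 2, 3, 1)) == (1, 16, 16, 8)
# 5D: NNCHW -> NNHWC via NNHWC_order; only the last three dims move
assert permute_shape((2, 1, 8, 16, 16), (0, 1, 3, 4, 2)) == (2, 1, 16, 16, 8)
```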
@@ -81,8 +82,12 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):

    @staticmethod
    def memory_format_differs(shape):
-        """Returns true if the shape will have a different memory layout in NCHW and NHWC format"""
-        if len(shape) >= 4:
+        """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
+        if len(shape) >= 5:
+            C = shape[2]
+            H = shape[3]
+            W = shape[4]
+        elif len(shape) == 4:
            C = shape[1]
            H = shape[2]
            W = shape[3]
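As an illustration of shapes where the (N)NCHW and (N)NHWC layouts do not differ in actual memory (e.g. the H == W == 1 case mentioned further down), a quick PyTorch check, independent of this pass:

```python
import torch

# With C == 1 or H == W == 1, NCHW and NHWC flatten to the same element order.
x = torch.arange(8).reshape(1, 1, 2, 4)   # C == 1
assert torch.equal(x.permute(0, 2, 3, 1).reshape(-1), x.reshape(-1))

y = torch.arange(8).reshape(1, 8, 1, 1)   # H == W == 1
assert torch.equal(y.permute(0, 2, 3, 1).reshape(-1), y.reshape(-1))
```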
@@ -98,14 +103,24 @@ def memory_format_differs(shape):
    @staticmethod
    def is_channel_reshape(input_shape, output_shape):
        """Returns true if the reshape changes the channel dimension"""
-        if not len(input_shape) == len(output_shape) == 4:
+        if not (
+            (len(input_shape) == len(output_shape) and (len(output_shape) in (4, 5)))
+            or (len(input_shape) == 4 and len(output_shape) == 5)
+            or (len(input_shape) == 5 and len(output_shape) == 4)
+        ):
            return False

-        C_old = input_shape[1]
-        C_new = output_shape[1]
+        C_old = input_shape[-3]
+        C_new = output_shape[-3]

-        N_new = output_shape[0]
-        N_old = input_shape[0]
+        N_new = (
+            output_shape[0]
+            if len(output_shape) == 4
+            else output_shape[0] * output_shape[1]
+        )
+        N_old = (
+            input_shape[0] if len(input_shape) == 4 else input_shape[0] * input_shape[1]
+        )

        return (N_old != N_new) or (C_old != C_new)

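A few illustrative calls, assuming the logic above; shapes are given in (N)NCHW order and the comments state the expected result:

```python
check = AnnotateChannelsLastDimOrder.is_channel_reshape

check((1, 8, 4, 4), (1, 4, 8, 4))       # True:  C changes 8 -> 4
check((1, 8, 4, 4), (1, 8, 2, 8))       # False: N and C unchanged
check((2, 1, 8, 4, 4), (2, 8, 4, 4))    # False: folded batch (2*1) and C both match
check((1, 8, 4, 4), (8, 4, 4))          # False: this rank pair is not handled here
```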
@@ -119,7 +134,11 @@ def insert_input_transpose(node, input_node, graph_module):
            torch.ops.passthrough_to_tosa._transpose.default,
            args=(
                input_node,
-                list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
+                list(
+                    AnnotateChannelsLastDimOrder.NNHWC_inverse_order
+                    if len(get_first_fake_tensor(input_node).size()) == 5
+                    else AnnotateChannelsLastDimOrder.NHWC_inverse_order
+                ),
            ),
            quantize=quantize,
            q_params=q_params,
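The inverse orders are defined so that they undo the corresponding channels-last permutation; a quick sanity check with PyTorch on arbitrary shapes:

```python
import torch

x4 = torch.randn(1, 8, 16, 16)
assert torch.equal(x4.permute(0, 2, 3, 1).permute(0, 3, 1, 2), x4)        # NHWC_order, then NHWC_inverse_order
x5 = torch.randn(2, 1, 8, 16, 16)
assert torch.equal(x5.permute(0, 1, 3, 4, 2).permute(0, 1, 4, 2, 3), x5)  # NNHWC_order, then NNHWC_inverse_order
```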
@@ -137,15 +156,28 @@ def insert_output_transpose(node, graph_module):
        permute_node = create_node(
            graph_module.graph,
            torch.ops.passthrough_to_tosa._transpose.default,
-            args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
+            args=(
+                node,
+                list(
+                    AnnotateChannelsLastDimOrder.NNHWC_order
+                    if len(get_first_fake_tensor(node).size()) == 5
+                    else AnnotateChannelsLastDimOrder.NHWC_order
+                ),
+            ),
        )
        permute_node.meta["tosa_dim_order"] = (
-            AnnotateChannelsLastDimOrder.NHWC_order
+            AnnotateChannelsLastDimOrder.NNHWC_order
+            if len(get_first_fake_tensor(node).size()) == 5
+            else AnnotateChannelsLastDimOrder.NHWC_order
+        )
+        permute_node.meta["val"] = get_first_fake_tensor(node).permute(
+            AnnotateChannelsLastDimOrder.NNHWC_order
+            if len(get_first_fake_tensor(node).size()) == 5
+            else AnnotateChannelsLastDimOrder.NHWC_order
        )
-        permute_node.meta["val"] = node.meta["val"].permute(
-            AnnotateChannelsLastDimOrder.NHWC_order
+        node.meta["tosa_dim_order"] = tuple(
+            range(len(get_first_fake_tensor(node).size()))
        )
-        node.meta["tosa_dim_order"] = (0, 1, 2, 3)
        users = [user for user in node.users if user != permute_node]
        for user in users:
            user.replace_input_with(node, permute_node)
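The hard-coded `(0, 1, 2, 3)` is replaced by an identity dim order of whatever rank the node actually has:

```python
# tuple(range(rank)) yields the identity dim order for any rank.
assert tuple(range(4)) == (0, 1, 2, 3)
assert tuple(range(5)) == (0, 1, 2, 3, 4)
```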
@@ -159,8 +191,8 @@ def insert_output_transpose(node, graph_module):
    def _insert_view_transpose(
        input_shape, output_shape, node, input_node, graph_module
    ):
-        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
-        nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
+        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
+        nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
            output_shape, input_shape
        )
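A sketch of how the relaxed comparisons classify rank transitions; `rank_transition_flags` is a hypothetical stand-in that mirrors just the two changed lines above:

```python
def rank_transition_flags(input_shape, output_shape):
    # Mirrors the two comparisons above (illustrative only).
    nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
    nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
    return nchw_to_nhwc, nhwc_to_nchw

assert rank_transition_flags((1, 8, 16), (1, 8, 4, 4)) == (True, False)     # 3D -> 4D
assert rank_transition_flags((2, 1, 8, 4, 4), (2, 8, 16)) == (False, True)  # 5D -> 3D
assert rank_transition_flags((1, 8, 4, 4), (1, 4, 8, 4)) == (False, False)  # 4D -> 4D
```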
@@ -178,11 +210,11 @@ def _insert_view_transpose(

    def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
        """
-        Transposes are needed for operators transforming the input to a different rank, as 4D-tensors are assumed to be in NHWC-format, whereas all other are in NCHW format.
+        Transposes are needed for operators transforming the input to a different rank, as 4D and 5D tensors are assumed to be in (N)NHWC format, whereas all others are in (N)NCHW format.
        This is relevant for the following cases:
-        - view: <4D -> 4D
-        - view: 4D -> <4D
-        Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leadning to one extra input and output transpose for this case.
+        - view: <4D -> >=4D
+        - view: >=4D -> <4D
+        Additionally, a 4D/5D -> 4D/5D view operation acting on the channel dimension currently needs to be performed in (N)NCHW format, leading to one extra input and output transpose for this case.

        Transposes can be avoided for shapes where there is no difference in actual memory, e.g for
        - H == W == 1
@@ -212,12 +244,13 @@ def call(self, graph_module: torch.fx.GraphModule):
                # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                dim_order = self.HWCM_order
+            elif node_data.dim() == 5:
+                dim_order = self.NNHWC_order  # type: ignore[assignment]
            else:
                dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
            node.meta["tosa_dim_order"] = dim_order
-        # Take care of cases when:
-        #   4D (NHWC) -> >4D (NCH)
-        #   3D (NCH) -> 4D (NHWC)
+        # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
+        # See insert_tosa_transposes for insertion conditions.
        self.insert_tosa_transposes(graph_module)
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module