@@ -10,8 +10,10 @@
 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    is_param_node,
 )
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
+from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.library import impl, Library
@@ -40,7 +42,14 @@ def _transpose_impl(*args, **kwargs):
     return args[0]
 
 
-class AnnotateChannelsLastDimOrder(ExportPass):
+def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
+    """
+    Returns True if the node is an input node, i.e. a placeholder that is not a parameter.
48+ """
49+ return node .op == "placeholder" and not is_param_node (exported_program , node )
50+
51+
52+ class ToTosaMemoryFormatPass (ExportPass ):
4453 """
4554 Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
4655 that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
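
A note on the new _is_input helper: in an ExportedProgram, parameters and buffers are lifted into graph placeholders, so checking node.op == "placeholder" alone would also match weights. A minimal sketch of the distinction, assuming an ExportedProgram named ep is at hand:

    # Sketch: list which placeholders are genuine user inputs (ep is assumed).
    for node in ep.graph_module.graph.nodes:
        if node.op == "placeholder":
            # is_param_node() consults ep.graph_signature to recognize lifted
            # parameters and buffers; everything else is a runtime input.
            print(node.name, "is user input:", not is_param_node(ep, node))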
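
And a quick plain-torch illustration of the (0, 2, 3, 1) dim-order the docstring mentions (a sketch, independent of the pass):

    import torch

    x = torch.randn(2, 8, 16, 16)            # NCHW
    nhwc = x.permute(0, 2, 3, 1)             # dim_order (0, 2, 3, 1) -> NHWC
    assert nhwc.shape == (2, 16, 16, 8)
    assert nhwc.permute(0, 3, 1, 2).shape == x.shape   # the inverse order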
@@ -54,6 +63,10 @@ class AnnotateChannelsLastDimOrder(ExportPass):
     NNHWC_order = (0, 1, 3, 4, 2)
     NNHWC_inverse_order = (0, 1, 4, 2, 3)
 
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        self.exported_program = exported_program
+        super().__init__()
+
     def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         """
         returns True for w in the following sequence;
@@ -70,6 +83,6 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
 
         return False
 
     @staticmethod
     def memory_format_differs(shape):
         """Returns true if the shape will have a different memory layout in (N)NCHW and (N)NHWC format"""
@@ -116,25 +130,30 @@ def is_channel_reshape(input_shape, output_shape):
 
     @staticmethod
     def insert_input_transpose(node, input_node, graph_module):
+        if input_node.target == torch.ops.passthrough_to_tosa._transpose.default:
+            pre_permute_node = input_node.all_input_nodes[0]
+            node.replace_input_with(input_node, pre_permute_node)
+            return
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose.default,
                 args=(
                     input_node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_inverse_order
+                        ToTosaMemoryFormatPass.NNHWC_inverse_order
                         if len(get_first_fake_tensor(input_node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_inverse_order
+                        else ToTosaMemoryFormatPass.NHWC_inverse_order
                     ),
                 ),
+                from_node=node,
             )
             node.replace_input_with(input_node, permute_node)
 
         permute_node.meta["tosa_dim_order"] = tuple(
             range(len(input_node.meta["val"].size()))
         )
-        permute_node.meta["val"] = input_node.meta["val"]
 
     @staticmethod
     def insert_output_transpose(node, graph_module):
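
The new early return makes insert_input_transpose cancel back-to-back transposes instead of emitting _transpose(_transpose(x)): when the producer is already a passthrough_to_tosa._transpose, the consumer is rewired straight to the producer's input. This is sound when the existing transpose and the one that would be inserted are inverse permutations, which appears to hold by construction in this pass. A small sketch (hypothetical helper, not part of the pass):

    def compose(perm, then):
        # x.permute(perm).permute(then) == x.permute(compose(perm, then))
        return [perm[i] for i in then]

    # NHWC_order followed by NHWC_inverse_order is the identity,
    # so the transpose pair can be dropped entirely.
    assert compose([0, 2, 3, 1], [0, 3, 1, 2]) == [0, 1, 2, 3]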
@@ -145,25 +164,23 @@ def insert_output_transpose(node, graph_module):
                 args=(
                     node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_order
+                        ToTosaMemoryFormatPass.NNHWC_order
                         if len(get_first_fake_tensor(node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_order
+                        else ToTosaMemoryFormatPass.NHWC_order
                     ),
                 ),
+                from_node=node,
             )
+
         permute_node.meta["tosa_dim_order"] = (
-            AnnotateChannelsLastDimOrder.NNHWC_order
+            ToTosaMemoryFormatPass.NNHWC_order
             if len(get_first_fake_tensor(node).size()) == 5
-            else AnnotateChannelsLastDimOrder.NHWC_order
-        )
-        permute_node.meta["val"] = get_first_fake_tensor(node).permute(
-            AnnotateChannelsLastDimOrder.NNHWC_order
-            if len(get_first_fake_tensor(node).size()) == 5
-            else AnnotateChannelsLastDimOrder.NHWC_order
+            else ToTosaMemoryFormatPass.NHWC_order
         )
         node.meta["tosa_dim_order"] = tuple(
             range(len(get_first_fake_tensor(node).size()))
         )
+
         users = [user for user in node.users if user != permute_node]
         for user in users:
             user.replace_input_with(node, permute_node)
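
For the 5D case handled above, NNHWC_order = (0, 1, 3, 4, 2) and NNHWC_inverse_order = (0, 1, 4, 2, 3) round-trip in the same way (plain-torch sketch):

    import torch

    x = torch.randn(2, 2, 8, 16, 16)                       # NNCHW
    nnhwc = x.permute(0, 1, 3, 4, 2)                       # NNHWC_order
    assert nnhwc.shape == (2, 2, 16, 16, 8)
    assert nnhwc.permute(0, 1, 4, 2, 3).shape == x.shape   # NNHWC_inverse_order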
@@ -174,20 +191,23 @@ def _insert_view_transpose(
     ):
         nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
         nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
-        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+        channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape(
             output_shape, input_shape
         )
 
         if (
             channel_reshape or nhwc_to_nchw
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
-            AnnotateChannelsLastDimOrder.insert_input_transpose(
+        ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape):
+
+            ToTosaMemoryFormatPass.insert_input_transpose(
                 node, input_node, graph_module
             )
+
         if (
             channel_reshape or nchw_to_nhwc
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
-            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+        ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape):
+
+            ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
 
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
@@ -205,9 +225,10 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             # call_function and placeholder allowed due to
             # index.Tensor being able to come in as both
-            if node.op not in ["call_function", "placeholder"]:
+            if node.op not in ["call_function", "placeholder", "output"]:
                 continue
 
+            # Transpose views
             elif node.target in (
                 exir_ops.edge.aten.view_copy.default,
                 exir_ops.edge.aten.index.Tensor,
@@ -218,25 +239,48 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
                 output_shape = node.meta["val"].shape
-
                 self._insert_view_transpose(
-                    input_shape, output_shape, node, input_node, graph_module
+                    input_shape,
+                    output_shape,
+                    node,
+                    input_node,
+                    graph_module,
                 )
 
+            # Transpose inputs
+            elif _is_input(node, self.exported_program):
+                input_shape = get_first_fake_tensor(node).size()
+                if len(input_shape) in (4, 5):
+                    ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
+
+            # Transpose outputs
+            elif node.op == "output":
+                output_shape = get_first_fake_tensor(node).size()
+
+                if len(output_shape) in (4, 5):
+                    for input_node in node.all_input_nodes:
+                        ToTosaMemoryFormatPass.insert_input_transpose(
+                            node, input_node, graph_module
+                        )
+
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             node_data = get_first_fake_tensor(node).data
 
-            if node_data.dim() == 4:
+            # Inputs and outputs are always in (N)NCHW format
+            if _is_input(node, self.exported_program) or node.op == "output":
+                dim_order = tuple(range(node_data.dim()))
+            elif node_data.dim() == 4:
                 dim_order = self.NHWC_order
                 if self.is_weight_node_for_depthwise_conv2d(node):
                     # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                     # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                     dim_order = self.HWCM_order
             elif node_data.dim() == 5:
-                dim_order = self.NNHWC_order  # type: ignore[assignment]
+                dim_order = self.NNHWC_order
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
+
             node.meta["tosa_dim_order"] = dim_order
             # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
             # See insert_tosa_transposes for insertion conditions.
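
Taken together, the dim orders assigned in call can be summarized with a small sketch (restating the branch above; HWCM_order is (2, 3, 0, 1) per the comment in the diff):

    def expected_dim_order(is_io, rank, is_dw_conv2d_weight=False):
        # Mirrors the branch in call(): graph inputs/outputs keep the identity
        # order (they stay (N)NCHW), 4D/5D tensors go channels-last, and
        # depthwise conv2d weights use the HWCM layout instead.
        if is_io or rank not in (4, 5):
            return tuple(range(rank))
        if rank == 5:
            return (0, 1, 3, 4, 2)                        # NNHWC_order
        return (2, 3, 0, 1) if is_dw_conv2d_weight else (0, 2, 3, 1)

    assert expected_dim_order(True, 4) == (0, 1, 2, 3)    # graph input/output
    assert expected_dim_order(False, 4) == (0, 2, 3, 1)   # NHWC_order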