 from executorch.backends.qualcomm.utils.constants import (
     QCOM_AXIS_ORDER,
     QCOM_INSERTED_PERMUTE,
+    QCOM_LAYOUT_CHANGE,
     QCOM_QUANT_ATTRS,
     QCOM_REQUANTIZE,
 )
@@ -34,6 +35,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.convolution.default,
         exir_ops.edge.aten.max_pool2d_with_indices.default,
         exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
+        exir_ops.edge.aten.native_group_norm.default,
         exir_ops.edge.aten.pixel_shuffle.default,
         exir_ops.edge.aten.pixel_unshuffle.default,
         exir_ops.edge.aten.upsample_bilinear2d.default,
@@ -95,6 +97,7 @@ def __init__(
         self.edge_program = edge_program
         self.insert_permute = insert_permute
         self.qdq_opset = {*q_ops, *dq_ops}
+        self.transformed_tag = QCOM_AXIS_ORDER

     def mark_as_transformed(self, node: torch.fx.Node) -> None:
         if isinstance(node.meta["val"], (tuple, list)):
@@ -105,18 +108,18 @@ def mark_as_transformed(self, node: torch.fx.Node) -> None:
                     f"got {getitem_node.target.__name__}"
                 )
             index = getitem_node.args[1]
-            node.meta[QCOM_AXIS_ORDER] = self.get_axis_order(
+            node.meta[self.transformed_tag] = self.get_axis_order(
                 eval_shape(node.meta["val"][index].shape)
             )
         else:
-            node.meta[QCOM_AXIS_ORDER] = self.get_axis_order(
+            node.meta[self.transformed_tag] = self.get_axis_order(
                 eval_shape(node.meta["val"].shape)
             )

     def is_transformed_node(self, node: torch.fx.Node) -> bool:
         if not hasattr(node, "meta"):
             return False
-        return QCOM_AXIS_ORDER in node.meta
+        return self.transformed_tag in node.meta

     def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
         return node.target in self.layout_sensitive_ops
@@ -186,8 +189,23 @@ def insert_node(self, graph_module, node, revert_layout: bool) -> None:
         # we need this to check the annotation boundary
         permute.meta[QCOM_INSERTED_PERMUTE] = True

+        # this handles the case where a residual connection is present,
+        # e.g. consider the following graph:
+        # x --> permute --> layer_norm --> permute --> conv2d --> add
+        #   └-------------------------------------┙
+        # the permute node should be inserted as:
+        # x --> permute --> layer_norm --> permute --> qnn_permute --> conv2d --> add
+        #   └--------------------------------------> qnn_permute -┙
+        # i.e. when there are multiple users, insert the permute conditionally
+        # on each edge between the current node and that user
+        is_node_transformed = self.is_transformed_node(node)
         for user in users:
-            user.replace_input_with(node, permute)
+            is_user_transformed = (
+                self.is_transformed_node(user) or QCOM_LAYOUT_CHANGE in user.meta
+            )
+            # insert a permute only when exactly one side is layout-transformed
+            if is_node_transformed != is_user_transformed:
+                user.replace_input_with(node, permute)

     def create_call_function_node(
         self,
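To make the exclusive condition above concrete, here is a minimal, self-contained sketch (a toy `Node` class and tag strings standing in for `torch.fx.Node` and the `QCOM_*` constants; not code from this change): a permute is placed on a node→user edge only when exactly one endpoint is marked as layout-transformed, which is what keeps the residual branch into `add` correct.

```python
# Illustrative sketch only: the toy Node class and tag strings below stand in
# for torch.fx.Node and the QCOM_AXIS_ORDER / QCOM_LAYOUT_CHANGE constants.
class Node:
    def __init__(self, name, meta=None):
        self.name = name
        self.meta = meta or {}  # mimics torch.fx.Node.meta


AXIS_ORDER = "axis_order"        # stands in for QCOM_AXIS_ORDER
LAYOUT_CHANGE = "layout_change"  # stands in for QCOM_LAYOUT_CHANGE


def needs_permute(node: Node, user: Node) -> bool:
    """Mirror of the exclusive condition in insert_node: a permute goes on the
    node -> user edge only when the two ends disagree about the layout."""
    node_transformed = AXIS_ORDER in node.meta
    user_transformed = AXIS_ORDER in user.meta or LAYOUT_CHANGE in user.meta
    return node_transformed != user_transformed


# Residual example: the permute output feeds both conv2d (already marked by
# the first traversal) and add (left untouched).
permute_out = Node("permute", meta={AXIS_ORDER: (0, 2, 3, 1)})
conv2d = Node("conv2d", meta={LAYOUT_CHANGE: True})
add = Node("add")

for user in (conv2d, add):
    print(user.name, needs_permute(permute_out, user))
# conv2d False -> both ends transformed, no extra permute on this edge
# add    True  -> layout boundary, a qnn permute is inserted here
```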
@@ -243,6 +261,15 @@ def call(self, graph_module: torch.fx.GraphModule):
         sensitive_nodes = [
             node for node in graph.nodes if self.is_layout_sensitive(node)
         ]
+        # perform a first traversal to identify nodes subject to layout changes
+        if self.insert_permute:
+            self.insert_permute, self.transformed_tag = False, QCOM_LAYOUT_CHANGE
+            for node in sensitive_nodes:
+                if not self.is_transformed_node(node):
+                    self.mark_as_transformed(node)
+                    self.traverse(node, graph_module)
+            self.insert_permute, self.transformed_tag = True, QCOM_AXIS_ORDER
+
         for node in sensitive_nodes:
             if not self.is_transformed_node(node):
                 self.mark_as_transformed(node)
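For reference, a rough sketch of the resulting two-pass flow in `call` (a hypothetical free function reusing the methods shown in this diff, given a `LayoutTransform` instance): the first, dry traversal only tags layout-affected nodes with `QCOM_LAYOUT_CHANGE`, and the second traversal re-tags them with `QCOM_AXIS_ORDER` while `insert_node` adds permutes only on edges whose endpoints disagree.

```python
from executorch.backends.qualcomm.utils.constants import (
    QCOM_AXIS_ORDER,
    QCOM_LAYOUT_CHANGE,
)


def two_pass_layout_transform(layout_pass, graph_module, sensitive_nodes):
    # Hypothetical helper: `layout_pass` is a LayoutTransform instance.
    # Pass 1 (dry run): tag everything reachable from layout-sensitive ops
    # with QCOM_LAYOUT_CHANGE; no permute nodes are inserted yet.
    layout_pass.insert_permute, layout_pass.transformed_tag = False, QCOM_LAYOUT_CHANGE
    for node in sensitive_nodes:
        if not layout_pass.is_transformed_node(node):
            layout_pass.mark_as_transformed(node)
            layout_pass.traverse(node, graph_module)

    # Pass 2 (real run): tag with QCOM_AXIS_ORDER; insert_node now places
    # permutes only on edges where node and user disagree about the layout.
    layout_pass.insert_permute, layout_pass.transformed_tag = True, QCOM_AXIS_ORDER
    for node in sensitive_nodes:
        if not layout_pass.is_transformed_node(node):
            layout_pass.mark_as_transformed(node)
            layout_pass.traverse(node, graph_module)
```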