from typing import cast

import torch
from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
)
from executorch.backends.arm.tosa_quant_utils import dq_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.library import impl, Library
# Define lib with passthrough operators. The operators have no real meaning in edge IR
# except for argument validation and a passthrough output. The operators will be used
# when lowering to TOSA, e.g. a passthrough_to_tosa._transpose will not affect
# the edge IR graph but will be lowered to a TOSA-TRANSPOSE.
lib = Library("passthrough_to_tosa", "DEF")
# For operators that change the rank of the input, such as unsqueeze and squeeze, we may need
# to switch dim_order before the operation. Changing tosa_dim_order is not sufficient
# as we also need transpose the data into the correct data format.
# By utilizing an edge IR passthrough operator we can keep the edge program in
# channels-first/contiguous and get the desired behavior in the TOSA lowering.
lib.define("_transpose(Tensor self, int[] dim_order) -> Tensor")


@impl(lib, "_transpose")
def _transpose_impl(*args, **kwargs):
    """Edge-IR passthrough implementation of passthrough_to_tosa._transpose.

    Validates the dim_order argument and returns the input tensor unchanged;
    the actual transpose is materialized only during TOSA lowering.
    """
    dim_order = args[1]
    # Ranks above 4 are not representable in the TOSA lowering this feeds.
    assert len(dim_order) <= 4
    # Pass-through in edge IR: output is the (unmodified) input tensor.
    return args[0]
class AnnotateChannelsLastDimOrder(ExportPass):
    """
    Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
    that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts passthrough_to_tosa._transpose
    when a transition between 3D and 4D tensors happens.
    The annotated tosa_dim_order is used to permute the node's shape such that it gives a TOSA-compliant shape.
    """

    NHWC_order = (0, 2, 3, 1)
    NHWC_inverse_order = (0, 3, 1, 2)
    HWCM_order = (2, 3, 0, 1)
2656 def is_weight_node_for_depthwise_conv2d (self , node : torch .fx .Node ):
2757 """
2858 returns True for dq and w in the following sequences;
@@ -49,20 +79,56 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
4979
5080 return False
5181
82+ def insert_tosa_transposes (self , graph_module : torch .fx .GraphModule ):
83+ for node in graph_module .graph .nodes :
84+ if node .op != "call_function" :
85+ continue
86+ if node .target == exir_ops .edge .aten .squeeze_copy .dims :
87+ input_node = node .args [0 ]
88+ if input_node .meta ["val" ].dim () == 4 :
89+ with graph_module .graph .inserting_before (node ):
90+ permute_node = create_node (
91+ graph_module .graph ,
92+ torch .ops .passthrough_to_tosa ._transpose ,
93+ args = (input_node , list (self .NHWC_inverse_order )),
94+ )
95+ permute_node .meta ["tosa_dim_order" ] = tuple (
96+ range (len (input_node .meta ["val" ].size ()))
97+ )
98+ node .replace_input_with (input_node , permute_node )
99+
100+ if node .target == exir_ops .edge .aten .unsqueeze_copy .default :
101+ if node .meta ["val" ].dim () == 4 :
102+ with graph_module .graph .inserting_after (node ):
103+ permute_node = create_node (
104+ graph_module .graph ,
105+ torch .ops .passthrough_to_tosa ._transpose ,
106+ args = (node , list (self .NHWC_order )),
107+ )
108+ permute_node .meta ["tosa_dim_order" ] = self .NHWC_order
109+ node .meta ["tosa_dim_order" ] = (0 , 1 , 2 , 3 )
110+ users = [user for user in node .users if user != permute_node ]
111+ for user in users :
112+ user .replace_input_with (node , permute_node )
113+
52114 def call (self , graph_module : torch .fx .GraphModule ):
53- NHWC_Order = (0 , 2 , 3 , 1 )
54- HWCM_Order = (2 , 3 , 0 , 1 )
55115 for node in graph_module .graph .nodes :
56116 node_data = get_first_fake_tensor (node ).data
57117
58- if len ( node_data .shape ) == 4 :
59- dim_order = NHWC_Order
118+ if node_data .dim ( ) == 4 :
119+ dim_order = self . NHWC_order
60120 if self .is_weight_node_for_depthwise_conv2d (node ):
61121 # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
62122 # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
63- dim_order = HWCM_Order
123+ dim_order = self . HWCM_order
64124 else :
65125 dim_order = tuple (range (node_data .dim ()))
66126 node .meta ["tosa_dim_order" ] = dim_order
127+ # Take care of cases when:
128+ # 4D (NHWC) -> >4D (NCH)
129+ # 3D (NCH) -> 4D (NHWC)
130+ self .insert_tosa_transposes (graph_module )
67131 graph_module .recompile ()
132+ graph_module = super ().call (graph_module ).graph_module
133+
68134 return PassResult (graph_module , True )