5
5
# LICENSE file in the root directory of this source tree.
6
6
7
7
import torch
8
- import torch .nn as nn
9
8
from executorch .backends .qualcomm .builders .utils import get_parameter , set_parameter
10
9
from executorch .backends .qualcomm .utils .constants import QCOM_REQUANTIZE
11
- from executorch .exir .dialects ._ops import ops as exir_ops
12
10
from executorch .exir .pass_base import ExportPass , PassResult
13
11
14
12
from .utils import copy_meta
@@ -23,16 +21,43 @@ class ConvertConv1dToConv2d(ExportPass):
23
21
def __init__(self, edge_program: torch.export.ExportedProgram):
    """Set up the pass with the exported program whose conv parameters it mutates.

    Args:
        edge_program: the exported program; used later to read and rewrite
            the convolution filter parameters in place.
    """
    super(ConvertConv1dToConv2d, self).__init__()
    self.edge_program = edge_program
    # Each supported 1D convolution overload is rewritten to its 2D
    # counterpart; the transposed variant targets the `.input` overload.
    aten = torch.ops.aten
    self.conv_op_map = {
        aten.conv1d.default: aten.conv2d.default,
        aten.conv_transpose1d.default: aten.conv_transpose2d.input,
    }
28
+
29
def append_qdq(
    self,
    graph_module: torch.fx.GraphModule,
    node: torch.fx.Node,
    qdq_node: torch.fx.Node,
):
    """Insert a quantize/dequantize pair right after `node`, reusing the
    quantization arguments of `qdq_node`.

    Args:
        graph_module: graph being rewritten.
        node: node whose output should be re-quantized.
        qdq_node: an existing per-tensor Q or DQ node supplying the
            quantization params (scale, zero point, qmin/qmax, dtype).

    Returns:
        The new dequantize node, or `node` unchanged when `qdq_node` is not
        a per-tensor quantize/dequantize op.
    """
    quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
    dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
    # Nothing to append unless the neighbouring node is a per-tensor Q/DQ.
    if qdq_node.target not in {quantize_op, dequantize_op}:
        return node

    graph = graph_module.graph
    # Clone the Q/DQ parameters verbatim; the last one is the target dtype.
    qdq_params = qdq_node.args[1:]

    with graph.inserting_after(node):
        quant = graph.create_node(
            "call_function", quantize_op, (node, *qdq_params)
        )
        quant.meta = copy_meta(node.meta)
        # Keep the fake-tensor metadata consistent with the quantized dtype.
        quant.meta["val"] = quant.meta["val"].to(qdq_params[-1])
    with graph.inserting_after(quant):
        dequant = graph.create_node(
            "call_function", dequantize_op, (quant, *qdq_params)
        )
        dequant.meta = copy_meta(node.meta)

    return dequant
26
53
27
54
def call (self , graph_module : torch .fx .GraphModule ):
28
55
graph = graph_module .graph
29
- conv_op = exir_ops .edge .aten .convolution .default
30
56
for node in graph .nodes :
31
- if node .target == conv_op and node .meta ["val" ].dim () == 3 :
32
-
57
+ if node .target in self .conv_op_map :
33
58
input_node = node .args [0 ]
34
59
with graph_module .graph .inserting_after (input_node ):
35
- unsqueeze_op = exir_ops . edge .aten .unsqueeze_copy .default
60
+ unsqueeze_op = torch . ops .aten .unsqueeze_copy .default
36
61
unsqueeze_node = graph .create_node (
37
62
"call_function" ,
38
63
unsqueeze_op ,
@@ -44,52 +69,88 @@ def call(self, graph_module: torch.fx.GraphModule):
44
69
unsqueeze_node .meta = copy_meta (
45
70
input_node .meta , lambda m : {** m , "val" : m ["val" ].unsqueeze (2 )}
46
71
)
72
+ qdq_node_after_unsqueeze = self .append_qdq (
73
+ graph_module = graph_module ,
74
+ node = unsqueeze_node ,
75
+ qdq_node = input_node ,
76
+ )
47
77
48
- with graph_module .graph .inserting_after (unsqueeze_node ):
49
-
50
- filter_node = node .args [1 ]
78
+ with graph_module .graph .inserting_after (qdq_node_after_unsqueeze ):
79
+ filter_arg = node .args [1 ]
80
+ filter_node = (
81
+ filter_arg
82
+ if filter_arg .op == "placeholder"
83
+ else node .args [1 ].args [0 ]
84
+ )
51
85
filter_node .meta ["val" ] = (
52
86
filter_node .meta ["val" ].unsqueeze (2 ).contiguous ()
53
87
)
54
- filter_tensor = get_parameter (filter_node , self .edge_program )
55
- # Ensure tensor is nn.Parameter type, so program does not fail during edge_program._validate()
56
- filter_tensor = nn .Parameter (filter_tensor .unsqueeze (2 ))
57
- set_parameter (filter_tensor , filter_node , self .edge_program )
88
+ filter_tensor = get_parameter (
89
+ filter_node , self .edge_program
90
+ ).unsqueeze (2 )
91
+ set_parameter (
92
+ (
93
+ torch .nn .Parameter (filter_tensor )
94
+ if filter_tensor .dtype == torch .float
95
+ else filter_tensor
96
+ ),
97
+ filter_node ,
98
+ self .edge_program ,
99
+ )
58
100
101
+ num_args = len (node .args )
59
102
bias_node = node .args [2 ]
60
- stride = [1 ] + node .args [3 ]
61
- padding = [0 ] + node .args [4 ]
62
- dilation = [1 ] + node .args [5 ]
63
- transpose = node .args [6 ]
64
- output_padding = [0 ] + node .args [7 ]
65
- groups = node .args [8 ]
66
-
67
- conv2d_node = graph .create_node (
68
- "call_function" ,
69
- conv_op ,
70
- (
71
- unsqueeze_node ,
72
- filter_node ,
103
+ stride = [1 ] + node .args [3 ] if num_args > 3 else [1 , 1 ]
104
+ padding = [0 ] + node .args [4 ] if num_args > 4 else [0 , 0 ]
105
+ if node .target == torch .ops .aten .conv1d .default :
106
+ dilation = [1 ] + node .args [5 ] if num_args > 5 else [1 , 1 ]
107
+ groups = node .args [6 ] if num_args > 5 else 1
108
+ conv_args = (
109
+ qdq_node_after_unsqueeze ,
110
+ node .args [1 ],
73
111
bias_node ,
74
112
stride ,
75
113
padding ,
76
114
dilation ,
77
- transpose ,
115
+ groups ,
116
+ )
117
+ else :
118
+ output_padding = (
119
+ [0 ] + node .args [5 ] if num_args > 5 else [0 , 0 ]
120
+ )
121
+ groups = node .args [6 ] if num_args > 6 else 1
122
+ dilation = [1 ] + node .args [7 ] if num_args > 7 else [1 , 1 ]
123
+ conv_args = (
124
+ qdq_node_after_unsqueeze ,
125
+ node .args [1 ],
126
+ bias_node ,
127
+ stride ,
128
+ padding ,
78
129
output_padding ,
79
130
groups ,
80
- ),
131
+ dilation ,
132
+ )
133
+ conv2d_node = graph .create_node (
134
+ "call_function" ,
135
+ self .conv_op_map [node .target ],
136
+ conv_args ,
81
137
)
82
138
conv2d_node .meta = copy_meta (
83
139
node .meta , lambda m : {** m , "val" : m ["val" ].unsqueeze (2 )}
84
140
)
141
+ qdq_node_after_conv2d = self .append_qdq (
142
+ graph_module = graph_module ,
143
+ node = conv2d_node ,
144
+ qdq_node = list (node .users )[0 ],
145
+ )
85
146
86
- with graph_module .graph .inserting_after (conv2d_node ):
87
- squeeze_op = exir_ops . edge .aten .squeeze_copy .dims
147
+ with graph_module .graph .inserting_after (qdq_node_after_conv2d ):
148
+ squeeze_op = torch . ops .aten .squeeze_copy .dims
88
149
squeeze_node = graph .create_node (
89
150
"call_function" ,
90
151
squeeze_op ,
91
152
(
92
- conv2d_node ,
153
+ qdq_node_after_conv2d ,
93
154
[2 ],
94
155
),
95
156
)
@@ -102,8 +163,10 @@ def call(self, graph_module: torch.fx.GraphModule):
102
163
QCOM_REQUANTIZE
103
164
]
104
165
conv2d_node .meta .pop (QCOM_REQUANTIZE , None )
166
+
105
167
for user in node .users .copy ():
106
168
user .replace_input_with (node , squeeze_node )
169
+
107
170
graph .eliminate_dead_code ()
108
171
graph_module .recompile ()
109
172
return PassResult (graph_module , True )
0 commit comments