 import torch.nn as nn
 from executorch.backends.qualcomm.builders.utils import get_parameter, set_parameter
 from executorch.backends.qualcomm.utils.constants import QCOM_REQUANTIZE
-from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

 from .utils import copy_meta
@@ -23,16 +22,43 @@ class ConvertConv1dToConv2d(ExportPass):
     def __init__(self, edge_program: torch.export.ExportedProgram):
         super(ConvertConv1dToConv2d, self).__init__()
         self.edge_program = edge_program
+        self.conv_op_map = {
+            torch.ops.aten.conv1d.default: torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv_transpose1d.default: torch.ops.aten.conv_transpose2d.input,
+        }
+
+    def append_qdq(
+        self,
+        graph_module: torch.fx.GraphModule,
+        node: torch.fx.Node,
+        qdq_node: torch.fx.Node,
+    ):
+        q_op = torch.ops.quantized_decomposed.quantize_per_tensor.default
+        dq_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default
+        if qdq_node.target not in {q_op, dq_op}:
+            return node
+
+        with graph_module.graph.inserting_after(node):
+            q_args = (node, *qdq_node.args[1:])
+            q_node = graph_module.graph.create_node("call_function", q_op, q_args)
+            q_node.meta = copy_meta(node.meta)
+            q_node.meta["val"] = q_node.meta["val"].to(q_args[-1])
+            with graph_module.graph.inserting_after(q_node):
+                dq_args = (q_node, *qdq_node.args[1:])
+                dq_node = graph_module.graph.create_node(
+                    "call_function", dq_op, dq_args
+                )
+                dq_node.meta = copy_meta(node.meta)
+
+        return dq_node

     def call(self, graph_module: torch.fx.GraphModule):
         graph = graph_module.graph
-        conv_op = exir_ops.edge.aten.convolution.default
         for node in graph.nodes:
-            if node.target == conv_op and node.meta["val"].dim() == 3:
-
+            if node.target in self.conv_op_map:
                 input_node = node.args[0]
                 with graph_module.graph.inserting_after(input_node):
-                    unsqueeze_op = exir_ops.edge.aten.unsqueeze_copy.default
+                    unsqueeze_op = torch.ops.aten.unsqueeze_copy.default
                     unsqueeze_node = graph.create_node(
                         "call_function",
                         unsqueeze_op,
@@ -44,10 +70,19 @@ def call(self, graph_module: torch.fx.GraphModule):
                     unsqueeze_node.meta = copy_meta(
                         input_node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
                     )
+                    qdq_node_after_unsqueeze = self.append_qdq(
+                        graph_module=graph_module,
+                        node=unsqueeze_node,
+                        qdq_node=input_node,
+                    )

-                with graph_module.graph.inserting_after(unsqueeze_node):
-
-                    filter_node = node.args[1]
+                with graph_module.graph.inserting_after(qdq_node_after_unsqueeze):
+                    filter_arg = node.args[1]
+                    filter_node = (
+                        filter_arg
+                        if filter_arg.op == "placeholder"
+                        else node.args[1].args[0].args[0]
+                    )
                     filter_node.meta["val"] = (
                         filter_node.meta["val"].unsqueeze(2).contiguous()
                     )
@@ -56,40 +91,59 @@ def call(self, graph_module: torch.fx.GraphModule):
                     filter_tensor = nn.Parameter(filter_tensor.unsqueeze(2))
                     set_parameter(filter_tensor, filter_node, self.edge_program)

+                    num_args = len(node.args)
                     bias_node = node.args[2]
-                    stride = [1] + node.args[3]
-                    padding = [0] + node.args[4]
-                    dilation = [1] + node.args[5]
-                    transpose = node.args[6]
-                    output_padding = [0] + node.args[7]
-                    groups = node.args[8]
-
-                    conv2d_node = graph.create_node(
-                        "call_function",
-                        conv_op,
-                        (
-                            unsqueeze_node,
-                            filter_node,
+                    stride = [1] + node.args[3] if num_args > 3 else [1, 1]
+                    padding = [0] + node.args[4] if num_args > 4 else [0, 0]
+                    if node.target == torch.ops.aten.conv1d.default:
+                        dilation = [1] + node.args[5] if num_args > 5 else [1, 1]
+                        groups = node.args[6] if num_args > 6 else 1
+                        conv_args = (
+                            qdq_node_after_unsqueeze,
+                            node.args[1],
                             bias_node,
                             stride,
                             padding,
                             dilation,
-                            transpose,
+                            groups,
+                        )
+                    else:
+                        output_padding = (
+                            [0] + node.args[5] if num_args > 5 else [0, 0]
+                        )
+                        groups = node.args[6] if num_args > 6 else 1
+                        dilation = [1] + node.args[7] if num_args > 7 else [1, 1]
+                        conv_args = (
+                            qdq_node_after_unsqueeze,
+                            node.args[1],
+                            bias_node,
+                            stride,
+                            padding,
                             output_padding,
                             groups,
-                        ),
+                            dilation,
+                        )
+                    conv2d_node = graph.create_node(
+                        "call_function",
+                        self.conv_op_map[node.target],
+                        conv_args,
                     )
                     conv2d_node.meta = copy_meta(
                         node.meta, lambda m: {**m, "val": m["val"].unsqueeze(2)}
                     )
+                    qdq_node_after_conv2d = self.append_qdq(
+                        graph_module=graph_module,
+                        node=conv2d_node,
+                        qdq_node=list(node.users)[0],
+                    )

-                with graph_module.graph.inserting_after(conv2d_node):
-                    squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
+                with graph_module.graph.inserting_after(qdq_node_after_conv2d):
+                    squeeze_op = torch.ops.aten.squeeze_copy.dims
                     squeeze_node = graph.create_node(
                         "call_function",
                         squeeze_op,
                         (
-                            conv2d_node,
+                            qdq_node_after_conv2d,
                             [2],
                         ),
                     )
@@ -102,8 +156,10 @@ def call(self, graph_module: torch.fx.GraphModule):
                             QCOM_REQUANTIZE
                         ]
                         conv2d_node.meta.pop(QCOM_REQUANTIZE, None)
+
                 for user in node.users.copy():
                     user.replace_input_with(node, squeeze_node)
+
         graph.eliminate_dead_code()
         graph_module.recompile()
         return PassResult(graph_module, True)
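
The rewrite uses the standard 1D-to-2D lowering trick: unsqueeze a dummy height dimension at dim 2, run the 2D op with the 1D stride/dilation prepended by 1 and the padding prepended by 0, then squeeze the height back out. A minimal sketch of that equivalence in plain PyTorch (shapes and hyperparameters below are illustrative only, not taken from this patch):

import torch
import torch.nn.functional as F

# Illustrative shapes: (N, C_in, L) input, (C_out, C_in, K) weight.
x = torch.randn(1, 8, 32)
w = torch.randn(16, 8, 3)
b = torch.randn(16)

ref = F.conv1d(x, w, b, stride=2, padding=1, dilation=1)

# What the pass emits: unsqueeze dim 2, conv2d with stride/dilation
# prepended by 1 and padding prepended by 0, then squeeze dim 2.
out = F.conv2d(
    x.unsqueeze(2),   # (N, C_in, 1, L)
    w.unsqueeze(2),   # (C_out, C_in, 1, K)
    b,
    stride=[1, 2],
    padding=[0, 1],
    dilation=[1, 1],
).squeeze(2)

torch.testing.assert_close(ref, out)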
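For orientation, the pass runs on the graph module of an exported program, as its call signature above suggests. A rough usage sketch, assuming a toy module and that the exported graph still carries aten.conv1d.default before edge lowering (the import path and model are placeholders, not part of this commit):

import torch
# Assumed import path; adjust to wherever ConvertConv1dToConv2d is defined.
from executorch.backends.qualcomm._passes import ConvertConv1dToConv2d

class ToyConv1d(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(8, 16, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(x)

exported = torch.export.export(ToyConv1d().eval(), (torch.randn(1, 8, 32),))

# The pass takes the exported program (to reach weights via get_parameter)
# and rewrites its graph module, returning a PassResult.
result = ConvertConv1dToConv2d(exported).call(exported.graph_module)
print(result.graph_module.graph)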