1313from executorch .backends .arm ._passes .decompose_sum_pass import DecomposeSumPass
1414from executorch .backends .arm ._passes .fuse_constant_ops_pass import ComputeConstantOpsAOT
1515from executorch .backends .arm ._passes .size_adjust_input_pass import SizeAdjustInputPass
16+ from executorch .backends .arm .constants import DQ_OPS , Q_OPS
1617from executorch .exir .backend .utils import WhyNoPartitionReporter
1718from executorch .exir .dialects ._ops import ops as exir_ops
1819from executorch .exir .pass_base import ExportPass
@@ -50,6 +51,15 @@ def get_view(op):
5051 raise RuntimeError (f"Can't get meandim decomposition for op { op } " )
5152
5253
def get_quantization(op):
    """Look up the quantize/dequantize operator pair that matches a dequant op.

    Returns a ``(quant_op, dequant_op)`` tuple when ``op`` is one of ``DQ_OPS``;
    the quant op is the entry of ``Q_OPS`` at the same position, so the two
    collections are relied upon to be index-aligned. Returns ``None`` when
    ``op`` is not a dequant op.
    """
    if op not in DQ_OPS:
        return None
    # The input of a dequant node may be a placeholder, so the matching quant
    # op is found by position in the op tables rather than by walking the graph.
    position = DQ_OPS.index(op)
    return Q_OPS[position], op
61+
62+
5363class DecomposeMeanDimPass (ArmPass ):
5464 """
5565 Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
@@ -121,6 +131,7 @@ def call_operator(self, op, args, kwargs, meta):
121131 dims_to_reduce = [dim - 1 for dim in dims_to_reduce ]
122132
123133 x = super ().call_operator (view_op , (x , new_shape ), {}, meta , True )
134+ x = self ._maybe_insert_q_dq_after (x , meta )
124135
125136 # Reduce (h,w) dims by avg pool if possible
126137 x , dims_to_reduce = self ._reduce_by_average_pool (op , x , dims_to_reduce , meta )
@@ -133,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta):
133144 dims_to_reduce = [dim + len (original_dims ) - 1 for dim in dims_to_reduce ]
134145
135146 x = super ().call_operator (view_op , (x , temp_shape ), {}, meta , True )
136-
147+ x = self . _maybe_insert_q_dq_after ( x , meta )
137148 # Reduce remaining dims by sum
138149 x = self ._reduce_by_sum (op , x , dims_to_reduce , meta , dtype )
139150
@@ -156,6 +167,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
156167 full = super ().call_operator (
157168 full_op , ([1 ] * len (output_shape ), 1 / N ), {"dtype" : dtype }, meta , True
158169 )
170+ if (quant_ops := get_quantization (input_node .node .target )) is not None :
171+ # Insert Q and DQ nodes after full op.
172+ # Since the value of full is known, we can compute quant params such that dq(q_max_value) equals 1/N exactly.
173+ q_op , dq_op = quant_ops
174+ qmax = input_node .node .args [4 ]
175+ full_quant_args = (
176+ 1 / (N * qmax ), # Scale to map qmax to 1/N
177+ 0 , # Zero point
178+ * input_node .node .args [3 :],
179+ )
180+ q_args = (full , * full_quant_args )
181+ full = super ().call_operator (
182+ q_op ,
183+ q_args ,
184+ kwargs = {},
185+ meta = meta ,
186+ updated = True ,
187+ )
188+ dq_args = (full , * full_quant_args )
189+ full = super ().call_operator (
190+ dq_op , dq_args , kwargs = {}, meta = meta , updated = True
191+ )
192+
193+ # Insert Q and DQ nodes after sum op.
194+ # Scale needs to be adjusted with N, since it was computed on data after the division with N.
195+ sum_quant_args = (input_node .node .args [1 ] * N , * input_node .node .args [2 :])
196+ q_args = (sum , * sum_quant_args )
197+ sum = super ().call_operator (
198+ q_op ,
199+ q_args ,
200+ kwargs = {},
201+ meta = meta ,
202+ updated = True ,
203+ )
204+ dq_args = (sum , * sum_quant_args )
205+ sum = super ().call_operator (
206+ dq_op , dq_args , kwargs = {}, meta = meta , updated = True
207+ )
208+
159209 return super ().call_operator (mul_op , (sum , full ), {}, meta , True )
160210
161211 def _reduce_by_average_pool (self , op , input_node , dims , meta ):
@@ -190,10 +240,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
190240 )
191241
192242 if is_supported :
243+ out = super ().call_operator (avgpool_op , args , {}, meta , True )
244+ out = self ._maybe_insert_q_dq_after (out , meta )
193245 return (
194- super (). call_operator ( avgpool_op , args , {}, meta , True ) ,
246+ out ,
195247 dims_to_reduce_by_sum ,
196248 )
197249
198250 else :
199251 return input_node , dims
252+
def _maybe_insert_q_dq_after(self, op, meta):
    """Insert a q-dq pair after ``op`` if its input node is a dequant node.

    The inserted quantize and dequantize ops reuse the quantization
    arguments of the input dequant node (everything after its first
    argument), so the pair carries identical quantization parameters.

    Args:
        op: Result of a previous ``call_operator``; its ``.node`` must have
            exactly one input node.
        meta: Node metadata forwarded to the inserted operators.

    Returns:
        The output of the inserted dequant op, or ``op`` unchanged when the
        input node is not a dequant node.

    Raises:
        ValueError: If ``op.node`` does not have exactly one input node.
    """
    input_nodes = op.node.all_input_nodes
    # Reject zero inputs as well as several: indexing [0] below would
    # otherwise fail with an unexplained IndexError.
    if len(input_nodes) != 1:
        raise ValueError(
            f"Expected one input to {op.node}, got inputs {input_nodes}"
        )

    input_node = input_nodes[0]
    quant_ops = get_quantization(input_node.target)
    if quant_ops is None:
        return op

    q_op, dq_op = quant_ops
    # args[0] of the dequant node is its tensor input; the rest are the
    # quantization parameters shared by the new q and dq nodes.
    quant_args = tuple(input_node.args[1:])
    out = super().call_operator(
        q_op,
        (op, *quant_args),
        kwargs={},
        meta=meta,
        updated=True,
    )
    return super().call_operator(
        dq_op,
        (out, *quant_args),
        kwargs={},
        meta=meta,
        updated=True,
    )
0 commit comments