diff --git a/backends/arm/_passes/decompose_meandim_pass.py b/backends/arm/_passes/decompose_meandim_pass.py
index 2ec7497ae82..1360fc44f98 100644
--- a/backends/arm/_passes/decompose_meandim_pass.py
+++ b/backends/arm/_passes/decompose_meandim_pass.py
@@ -13,6 +13,7 @@
 from executorch.backends.arm._passes.decompose_sum_pass import DecomposeSumPass
 from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
 from executorch.backends.arm._passes.size_adjust_input_pass import SizeAdjustInputPass
+from executorch.backends.arm.constants import DQ_OPS, Q_OPS
 from executorch.exir.backend.utils import WhyNoPartitionReporter
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass
@@ -50,6 +51,15 @@ def get_view(op):
     raise RuntimeError(f"Can't get meandim decomposition for op {op}")
 
 
+def get_quantization(op):
+    """Return the (quantize, dequantize) op pair matching op if op is in DQ_OPS, otherwise None."""
+    if op in DQ_OPS:
+        # The input of op can be a placeholder, so the quant op is looked up by index; Q_OPS and DQ_OPS must share ordering.
+        quant_type_index = DQ_OPS.index(op)
+        return Q_OPS[quant_type_index], op
+    return None
+
+
 class DecomposeMeanDimPass(ArmPass):
     """
     Decomposes a meandim into avg_pool and/or sum + mul (1/N) depending on which dims the mean is taken for:
@@ -121,6 +131,7 @@ def call_operator(self, op, args, kwargs, meta):
                 dims_to_reduce = [dim - 1 for dim in dims_to_reduce]
 
             x = super().call_operator(view_op, (x, new_shape), {}, meta, True)
+            x = self._maybe_insert_q_dq_after(x, meta)
 
         # Reduce (h,w) dims by avg pool if possible
         x, dims_to_reduce = self._reduce_by_average_pool(op, x, dims_to_reduce, meta)
@@ -133,7 +144,7 @@ def call_operator(self, op, args, kwargs, meta):
             dims_to_reduce = [dim + len(original_dims) - 1 for dim in dims_to_reduce]
 
             x = super().call_operator(view_op, (x, temp_shape), {}, meta, True)
-
+            x = self._maybe_insert_q_dq_after(x, meta)
         # Reduce remaining dims by sum
         x = self._reduce_by_sum(op, x, dims_to_reduce, meta, dtype)
 
@@ -156,6 +167,45 @@ def _reduce_by_sum(self, op, input_node, dims, meta, dtype):
         full = super().call_operator(
             full_op, ([1] * len(output_shape), 1 / N), {"dtype": dtype}, meta, True
         )
+        if (quant_ops := get_quantization(input_node.node.target)) is not None:
+            # Insert Q and DQ nodes after full op.
+            # Since the value of full (1/N) is known, we can pick quant params such that dq(qmax) == 1/N.
+            q_op, dq_op = quant_ops
+            qmax = input_node.node.args[4]  # NOTE(review): assumes dq args are (x, scale, zp, qmin, qmax, dtype) - confirm
+            full_quant_args = (
+                1 / (N * qmax),  # Scale to map qmax to 1/N
+                0,  # Zero point
+                *input_node.node.args[3:],
+            )
+            q_args = (full, *full_quant_args)
+            full = super().call_operator(
+                q_op,
+                q_args,
+                kwargs={},
+                meta=meta,
+                updated=True,
+            )
+            dq_args = (full, *full_quant_args)
+            full = super().call_operator(
+                dq_op, dq_args, kwargs={}, meta=meta, updated=True
+            )
+
+            # Insert Q and DQ nodes after sum op.
+            # The input scale was computed on data after the division by N, so the sum needs that scale multiplied by N.
+            sum_quant_args = (input_node.node.args[1] * N, *input_node.node.args[2:])
+            q_args = (sum, *sum_quant_args)
+            sum = super().call_operator(
+                q_op,
+                q_args,
+                kwargs={},
+                meta=meta,
+                updated=True,
+            )
+            dq_args = (sum, *sum_quant_args)
+            sum = super().call_operator(
+                dq_op, dq_args, kwargs={}, meta=meta, updated=True
+            )
+
         return super().call_operator(mul_op, (sum, full), {}, meta, True)
 
     def _reduce_by_average_pool(self, op, input_node, dims, meta):
@@ -190,10 +240,38 @@ def _reduce_by_average_pool(self, op, input_node, dims, meta):
         )
 
         if is_supported:
+            out = super().call_operator(avgpool_op, args, {}, meta, True)
+            out = self._maybe_insert_q_dq_after(out, meta)
             return (
-                super().call_operator(avgpool_op, args, {}, meta, True),
+                out,
                 dims_to_reduce_by_sum,
             )
 
         else:
             return input_node, dims
+
+    def _maybe_insert_q_dq_after(self, op, meta):
+        """Insert a q-dq pair after op, reusing its input's quantization parameters, if that input is a dequant node; otherwise return op. Raises ValueError unless op has exactly one input node."""
+
+        if len(op.node.all_input_nodes) != 1:
+            raise ValueError(
+                f"Expected one input to {op.node}, got inputs {op.node.all_input_nodes}"
+            )
+        input_node = op.node.all_input_nodes[0]
+        if (quant_ops := get_quantization(input_node.target)) is not None:
+            q_op, dq_op = quant_ops
+            quant_args = list(input_node.args[1:])
+            q_args = (op, *quant_args)
+            out = super().call_operator(
+                q_op,
+                q_args,
+                kwargs={},
+                meta=meta,
+                updated=True,
+            )
+            dq_args = (out, *quant_args)
+            return super().call_operator(
+                dq_op, dq_args, kwargs={}, meta=meta, updated=True
+            )
+        else:
+            return op
diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py
index 8310d1e40a4..797ce26ea7a 100644
--- a/backends/arm/test/ops/test_avg_pool2d.py
+++ b/backends/arm/test/ops/test_avg_pool2d.py
@@ -23,7 +23,7 @@
     VgfPipeline,
 )
 
-aten_op = "torch.ops.aten.avg_pool2d.default"
+aten_op = "avg_pool2d.default"
 exir_op = "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"
 
 input_t = Tuple[torch.Tensor]
@@ -34,6 +34,15 @@ def forward(self, *args, **kwargs):
         return super().forward(*args, **kwargs)
 
 
+class BecomesMeanInToEdge(torch.nn.Module):
+    """This average pool will be converted to mean when lowering to edge. This causes the decompose_meandim pass to not
+    trigger until the backend pipeline, which requires extra care.
+    """
+
+    def forward(self, x: torch.Tensor):
+        return torch.nn.functional.adaptive_avg_pool2d(x, (1, 1))
+
+
 test_modules = {
     "zeros": lambda: (AvgPool2d(4, 2, 0, False), (torch.zeros(1, 16, 50, 32),)),
     "ones": lambda: (AvgPool2d(4, 2, 0, False, True), (torch.ones(1, 16, 50, 32),)),
@@ -110,6 +119,9 @@ def forward(self, *args, **kwargs):
         AvgPool2d(3, (1, 3), 1, count_include_pad=False),
         (torch.rand(1, 16, 54, 54),),
     ),
+    "becomes_mean_rank3": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 8, 8),)),
+    "becomes_mean_rank4": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 8, 8),)),
+    "becomes_mean_rank5": lambda: (BecomesMeanInToEdge(), (torch.rand(2, 2, 2, 8, 8),)),
 }
 
 