From 4f2fc70893240ae2fab8f98a8660dc5e8c92c4f0 Mon Sep 17 00:00:00 2001
From: Eashan Garg
Date: Mon, 30 Jun 2025 09:14:40 -0700
Subject: [PATCH] Pass to replace Adaptive Avg. Pool with Aten Avg. Pool
 (#10818)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/10818

Some graphs contain exir_ops.edge.aten._adaptive_avg_pool2d.default nodes; this adds a pass to replace these with exir_ops.edge.aten.avg_pool2d.default.

Reviewed By: mcremon-meta

Differential Revision: D74559775
---
 backends/cadence/aot/replace_ops.py           |  62 +++++++++++
 .../aot/tests/test_replace_ops_passes.py      | 100 ++++++++++++++++++
 2 files changed, 162 insertions(+)

diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py
index d85a0cc9be4..3950f1852df 100644
--- a/backends/cadence/aot/replace_ops.py
+++ b/backends/cadence/aot/replace_ops.py
@@ -16,6 +16,7 @@
 # pyre-unsafe
 
+import logging
 import math
 import operator
 from operator import neg
@@ -2346,6 +2347,66 @@ def resolve_full_arg(self, x_arg, const_arg):
         return const_arg
 
 
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass):
+    """
+    Replace the aten adaptive avg_pool op with the aten avg_pool2d op.
+    """
+
+    def call_operator(self, op, args, kwargs, meta):
+        # Only continue for the adaptive avg_pool op
+        if op not in {exir_ops.edge.aten._adaptive_avg_pool2d.default}:
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Get the input tensor
+        in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
+        # Permute NCHW to NHWC for computation
+        in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
+        in_tensor_shape = in_tensor_permuted.shape
+
+        output_size = args[1]
+        num_dims = len(output_size)
+
+        # TODO: If in_tensor_shape is not a multiple of output size,
+        # this pass will not work. T224984800
+        dim_multiples = [
+            (in_tensor_shape[i + 1] % output_size[i]) == 0 for i in range(num_dims)
+        ]
+        if not all(dim_multiples):
+            logging.info(
+                f"Unable to replace adaptive average pool with average pool. Input tensor shape of {in_tensor_shape} is not a multiple of output size: {output_size}"
+            )
+            return super().call_operator(op, args, kwargs, meta)
+
+        # Compute stride and kernel_size, then set default values for other arguments
+        stride = [(in_tensor_shape[i + 1] // output_size[i]) for i in range(num_dims)]
+        kernel_size = [
+            in_tensor_shape[i + 1] - (output_size[i] - 1) * stride[i]
+            for i in range(num_dims)
+        ]
+        padding = [0] * num_dims
+        ceil_mode = False
+        count_include_pad = True
+        divisor_override = None
+
+        # Create a new avg_pool node with the updated args
+        new_args = (
+            args[0],
+            kernel_size,
+            stride,
+            padding,
+            ceil_mode,
+            count_include_pad,
+            divisor_override,
+        )
+        return super().call_operator(
+            exir_ops.edge.aten.avg_pool2d.default,
+            new_args,
+            kwargs,
+            meta,
+        )
+
+
 # This class encapsulates all the functions that replace/switch one op in the
 # graph with another.
 class CadenceReplaceOpsInGraph:
@@ -2382,6 +2443,7 @@ class CadenceReplaceOpsInGraph:
         ReplacePT2QuantWithCadenceQuantPass,
         ReplacePT2DequantWithCadenceDequantPass,
         ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
+        ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
         ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
         ReplaceWhereWithFullArgsWithWhereScalar,
         ReplaceAtenApproxGeluWithApproxGeluPass,
diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index 6d12c991d6d..0537889d2c2 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -19,6 +19,7 @@ from executorch.backends.cadence.aot.replace_ops import (
     ForceChannelLastForConvPass,
     MakeSliceAndCatDimOutermostPass,
+    ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
     ReplaceAddMMWithLinearPass,
     ReplaceAtenApproxGeluWithApproxGeluPass,
     ReplaceAtenConvolutionWithJarvisConvolutionPass,
@@ -1936,3 +1937,102 @@ def test_extract_mul_argument_to_full(
                 },
             )
         )
+
+
+class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase):
+    def _get_adaptive_avg_pool_gm(
+        self, input_shape: Tuple[int, int, int, int], output_shape: Tuple[int, int]
+    ) -> torch.fx.GraphModule:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*input_shape))
+        adaptive_avg_pool2d = builder.call_operator(
+            exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape)
+        )
+        builder.output([adaptive_avg_pool2d])
+        return builder.get_graph_module()
+
+    def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None:
+        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8))
+        self.assertEqual(
+            len(
+                gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
+                )
+            ),
+            1,
+        )
+        self.assertEqual(
+            len(
+                gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten.avg_pool2d.default,
+                )
+            ),
+            0,
+        )
+        p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()
+        updated_gm = p.call(gm).graph_module
+        self.assertEqual(
+            len(
+                updated_gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
+                )
+            ),
+            0,
+        )
+        avg_pool2d_nodes = updated_gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
+        )
+        self.assertEqual(
+            len(avg_pool2d_nodes),
+            1,
+        )
+        avg_pool2d_node = avg_pool2d_nodes[0]
+
+        self.assertEqual(avg_pool2d_node.args[1], [16, 16])  # kernel_size is 16x16
+        self.assertEqual(avg_pool2d_node.args[2], [16, 16])  # stride is 16, 16
+        self.assertEqual(avg_pool2d_node.args[3], [0, 0])  # padding is 0, 0
+        self.assertEqual(avg_pool2d_node.args[4], False)  # ceil_mode is False
+        self.assertEqual(avg_pool2d_node.args[5], True)  # count_include_pad is True
+        self.assertEqual(avg_pool2d_node.args[6], None)  # divisor_override is None
+
+    def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None:
+        gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9))
+        self.assertEqual(
+            len(
+                gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
+                )
+            ),
+            1,
+        )
+        self.assertEqual(
+            len(
+                gm.graph.find_nodes(
+                    op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
+                )
+            ),
+            0,
+        )
+        # Shapes are not multiples of each other, so pass will not trigger
+        p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()
+        updated_gm = p.call(gm).graph_module
+        self.assertEqual(
+            len(
+                updated_gm.graph.find_nodes(
+                    op="call_function",
+                    target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
+                )
+            ),
+            1,
+        )
+        avg_pool2d_nodes = updated_gm.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
+        )
+        self.assertEqual(
+            len(avg_pool2d_nodes),
+            0,
+        )
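
Sanity check for the kernel/stride derivation above (a standalone sketch, not part of the patch; it only assumes torch is importable): when every input spatial dim is a multiple of the target output size, avg_pool2d with stride = in // out and kernel = in - (out - 1) * stride averages exactly the same disjoint windows as _adaptive_avg_pool2d, so the two ops agree. For the (1, 64, 128, 128) -> (8, 8) case used in the tests, both work out to 16:

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 128, 128)
output_size = (8, 8)

# Same derivation the pass uses: stride = in // out, kernel = in - (out - 1) * stride
in_hw = x.shape[2:]
stride = [in_hw[i] // output_size[i] for i in range(2)]
kernel_size = [in_hw[i] - (output_size[i] - 1) * stride[i] for i in range(2)]
assert stride == [16, 16] and kernel_size == [16, 16]

# The fixed-window pool reproduces the adaptive pool (up to float tolerance)
torch.testing.assert_close(
    F.avg_pool2d(x, kernel_size=kernel_size, stride=stride),
    F.adaptive_avg_pool2d(x, output_size),
)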
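The irregular (9, 9) case exercised by the second test shows why the pass bails instead of approximating: adaptive pooling places window i at [floor(i * in / out), ceil((i + 1) * in / out)), so window sizes vary (15 or 16 here) and no single fixed kernel/stride can match, even though the naive derivation yields the right output shape. A quick check of that claim (again a standalone sketch, not part of the patch):

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 128, 128)
adaptive = F.adaptive_avg_pool2d(x, (9, 9))

# Naive derivation: stride = 128 // 9 = 14, kernel = 128 - (9 - 1) * 14 = 16
fixed = F.avg_pool2d(x, kernel_size=16, stride=14)

assert fixed.shape == adaptive.shape        # shapes agree: (1, 64, 9, 9)
assert not torch.allclose(fixed, adaptive)  # but the values do not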