|  | 
|  | 1 | +# Copyright 2025 Arm Limited and/or its affiliates. | 
|  | 2 | +# | 
|  | 3 | +# This source code is licensed under the BSD-style license found in the | 
|  | 4 | +# LICENSE file in the root directory of this source tree. | 
|  | 5 | + | 
|  | 6 | + | 
|  | 7 | +import torch | 
|  | 8 | +from executorch.backends.arm.operators.operator_validation_utils import ( | 
|  | 9 | +    adjust_pooling_pad_if_needed, | 
|  | 10 | +) | 
|  | 11 | +from executorch.exir.dialects._ops import ops as exir_ops | 
|  | 12 | +from executorch.exir.pass_base import ExportPass | 
|  | 13 | + | 
|  | 14 | +edge_div_ops = (exir_ops.edge.aten.avg_pool2d.default,) | 
|  | 15 | +aten_div_ops = (torch.ops.aten.avg_pool2d.default,) | 
|  | 16 | + | 
|  | 17 | + | 
|  | 18 | +def get_decomposition(op) -> tuple: | 
|  | 19 | +    if op in edge_div_ops: | 
|  | 20 | +        return ( | 
|  | 21 | +            exir_ops.edge.aten.full.default, | 
|  | 22 | +            exir_ops.edge.aten.cat.default, | 
|  | 23 | +            exir_ops.edge.aten.avg_pool2d.default, | 
|  | 24 | +            exir_ops.edge.aten.mul.Tensor, | 
|  | 25 | +        ) | 
|  | 26 | +    if op in aten_div_ops: | 
|  | 27 | +        return ( | 
|  | 28 | +            torch.ops.aten.full.default, | 
|  | 29 | +            torch.ops.aten.cat.default, | 
|  | 30 | +            torch.ops.aten.avg_pool2d.default, | 
|  | 31 | +            torch.ops.aten.mul.Tensor, | 
|  | 32 | +        ) | 
|  | 33 | +    raise RuntimeError(f"Can't get div decomposition for op {op}") | 
|  | 34 | + | 
|  | 35 | + | 
|  | 36 | +class DecomposeAvgPool2d(ExportPass): | 
|  | 37 | +    """ """ | 
|  | 38 | + | 
|  | 39 | +    def call_operator(self, op, args, kwargs, meta): | 
|  | 40 | +        if op not in (edge_div_ops + aten_div_ops): | 
|  | 41 | +            return super().call_operator(op, args, kwargs, meta) | 
|  | 42 | + | 
|  | 43 | +        full_op, cat_op, avgpool_op, mul_op = get_decomposition(op) | 
|  | 44 | + | 
|  | 45 | +        x = args[0] | 
|  | 46 | +        kernel_h, kernel_w = args[1] | 
|  | 47 | +        kernel_size = kernel_h * kernel_w | 
|  | 48 | +        stride_h, stride_w = args[2] | 
|  | 49 | +        pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0) | 
|  | 50 | +        ceil_mode = args[4] if len(args) > 4 else False | 
|  | 51 | +        count_include_pad = args[5] if len(args) > 5 else True | 
|  | 52 | +        divisor_override = args[6] if len(args) > 6 else None | 
|  | 53 | + | 
|  | 54 | +        n, c, h, w = x.data.shape | 
|  | 55 | +        post_pad_w, post_pad_h = (0, 0) | 
|  | 56 | + | 
|  | 57 | +        # Count_include_pad == False means that we use a different divisor for edge elements | 
|  | 58 | +        # When divisor_override is set, this will be overriden anyways. | 
|  | 59 | +        # It is easier to replace a constant divisor, so set count_include_pad == True | 
|  | 60 | +        if divisor_override is not None: | 
|  | 61 | +            count_include_pad = True | 
|  | 62 | + | 
|  | 63 | +        # Add width padding manually if count_include_pad | 
|  | 64 | +        if count_include_pad and pad_w > 0: | 
|  | 65 | +            pre_pad_shape = [n, c, h, pad_w] | 
|  | 66 | +            pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta) | 
|  | 67 | + | 
|  | 68 | +            if ceil_mode and divisor_override is None: | 
|  | 69 | +                post_pad_w = pad_w | 
|  | 70 | +            else: | 
|  | 71 | +                post_pad_w = adjust_pooling_pad_if_needed( | 
|  | 72 | +                    w, kernel_w, stride_w, pad_w, ceil_mode | 
|  | 73 | +                ) | 
|  | 74 | + | 
|  | 75 | +            if post_pad_w > 0: | 
|  | 76 | +                post_pad_shape = [n, c, h, post_pad_w] | 
|  | 77 | +                post_pad = super().call_operator( | 
|  | 78 | +                    full_op, (post_pad_shape, 0.0), kwargs, meta | 
|  | 79 | +                ) | 
|  | 80 | +                cat_nodes = [pre_pad, x, post_pad] | 
|  | 81 | +            else: | 
|  | 82 | +                cat_nodes = [pre_pad, x] | 
|  | 83 | + | 
|  | 84 | +            x = super().call_operator(cat_op, (cat_nodes, 3), kwargs, meta) | 
|  | 85 | +            new_pad_w = 0 | 
|  | 86 | + | 
|  | 87 | +        # Add height padding manually if count_include_pad | 
|  | 88 | +        if count_include_pad and pad_h > 0: | 
|  | 89 | +            pre_pad_shape = [n, c, pad_h, w + pad_w + post_pad_w] | 
|  | 90 | +            pre_pad = super().call_operator(full_op, (pre_pad_shape, 0.0), kwargs, meta) | 
|  | 91 | + | 
|  | 92 | +            if ceil_mode and divisor_override is None: | 
|  | 93 | +                post_pad_h = pad_h | 
|  | 94 | +            else: | 
|  | 95 | +                post_pad_h = adjust_pooling_pad_if_needed( | 
|  | 96 | +                    h, kernel_h, stride_h, pad_h, ceil_mode | 
|  | 97 | +                ) | 
|  | 98 | + | 
|  | 99 | +            if post_pad_h > 0: | 
|  | 100 | +                post_pad_shape = [n, c, post_pad_h, w + pad_w + post_pad_w] | 
|  | 101 | +                post_pad = super().call_operator( | 
|  | 102 | +                    full_op, (post_pad_shape, 0.0), kwargs, meta | 
|  | 103 | +                ) | 
|  | 104 | +                cat_nodes = [pre_pad, x, post_pad] | 
|  | 105 | +            else: | 
|  | 106 | +                cat_nodes = [pre_pad, x] | 
|  | 107 | + | 
|  | 108 | +            x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta) | 
|  | 109 | +            new_pad_h = 0 | 
|  | 110 | + | 
|  | 111 | +        avgpool_args = (x, args[1], args[2], [new_pad_h, new_pad_w], ceil_mode, False) | 
|  | 112 | +        x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta) | 
|  | 113 | + | 
|  | 114 | +        # Multiply by factor (kernel_size / divisor_override) if divisor_override | 
|  | 115 | +        if divisor_override is not None and divisor_override != kernel_size: | 
|  | 116 | +            override_multiplier = super().call_operator( | 
|  | 117 | +                full_op, ([1, 1, 1, 1], kernel_size / divisor_override), kwargs, meta | 
|  | 118 | +            ) | 
|  | 119 | +            x = super().call_operator(mul_op, (x, override_multiplier), kwargs, meta) | 
|  | 120 | + | 
|  | 121 | +        return x | 
0 commit comments