|
| 1 | +# Copyright 2025 Arm Limited and/or its affiliates. |
| 2 | +# |
| 3 | +# This source code is licensed under the BSD-style license found in the |
| 4 | +# LICENSE file in the root directory of this source tree. |
| 5 | + |
| 6 | +from math import prod |
| 7 | + |
| 8 | +import torch |
| 9 | +from executorch.backends.arm._passes import ArmPass |
| 10 | +from executorch.backends.arm._passes.arm_pass_utils import create_node |
| 11 | +from executorch.backends.arm._passes.quant_args import QuantArgs |
| 12 | + |
| 13 | +from executorch.backends.transforms.utils import create_constant_placeholder |
| 14 | +from executorch.exir.dialects._ops import ops as exir_ops |
| 15 | +from executorch.exir.pass_base import PassResult |
| 16 | +from torch.export.graph_signature import InputKind |
| 17 | + |
| 18 | + |
| 19 | +class DecomposeCumsumPass(ArmPass): |
| 20 | + """ |
| 21 | + Decomposes cumsum into a 1D convolution with a kernel of ones. |
| 22 | +
|
| 23 | + For example, the cumsum of an input tensor [1, 1] is [1, 1 + 1] = [1, 2]. |
| 24 | + To decompose this, take the input tensor and pre-padded with len(input)-1 zeros and |
| 25 | + slided over with a kernel [1,1], of length len(input): |
| 26 | +
|
| 27 | + Input: [0, 1, 1] |
| 28 | + Kernel: [1, 1] = [1] |
| 29 | + [1, 1] = [2] |
| 30 | +
|
| 31 | + Since pytorch only supports symmetric padding, in reality the result will have |
| 32 | + an additional 1 calculated at the end, which leads to an required extra slice op. |
| 33 | +
|
| 34 | + To extend this to higher dimensions, the input is reshaped to [N, C, H, W] with |
| 35 | + N = <dims before cumsum dim> |
| 36 | + C = 1 |
| 37 | + H = <cumsum dim> |
| 38 | + W = <dims after cumsum dim> |
| 39 | + And the convolution is applied over dimension H. |
| 40 | + """ |
| 41 | + |
| 42 | + def call(self, graph_module): |
| 43 | + graph = graph_module.graph |
| 44 | + targets = (exir_ops.edge.aten.cumsum.default, torch.ops.aten.cumsum.default) |
| 45 | + modified = False |
| 46 | + for node in list(graph.nodes): |
| 47 | + if node.op != "call_function" or node.target not in targets: |
| 48 | + continue |
| 49 | + |
| 50 | + if len(node.args) != 2: |
| 51 | + raise ValueError( |
| 52 | + "Cumsum node should have exactly two arguments: input and dim." |
| 53 | + ) |
| 54 | + |
| 55 | + # Get node data |
| 56 | + input_node, dim = node.args |
| 57 | + val = node.meta.get("val") |
| 58 | + original_shape = list(val.shape) |
| 59 | + dtype = input_node.meta.get("val").dtype |
| 60 | + dim = dim % len(original_shape) |
| 61 | + |
| 62 | + # Compute shapes |
| 63 | + pre_cumsum_dim = prod(original_shape[:dim]) if dim > 0 else 1 |
| 64 | + cumsum_dim = original_shape[dim] |
| 65 | + post_cumsum_dim = ( |
| 66 | + prod(original_shape[dim + 1 :]) if dim < len(original_shape) - 1 else 1 |
| 67 | + ) |
| 68 | + conv_shape = [ |
| 69 | + pre_cumsum_dim, |
| 70 | + 1, |
| 71 | + cumsum_dim, |
| 72 | + post_cumsum_dim, |
| 73 | + ] |
| 74 | + pad_shape = [original_shape[dim] - 1, 0] |
| 75 | + weight_shape = [1, 1, original_shape[dim], 1] |
| 76 | + |
| 77 | + # Create convolution weight |
| 78 | + with graph.inserting_before(list(graph.nodes)[0]): |
| 79 | + weight_data = torch.ones(size=weight_shape, dtype=dtype) |
| 80 | + weight_node = create_constant_placeholder( |
| 81 | + self.exported_program, |
| 82 | + graph, |
| 83 | + node.name + "_kernel", |
| 84 | + InputKind.PARAMETER, |
| 85 | + weight_data, |
| 86 | + ) |
| 87 | + |
| 88 | + # Create decomposed nodes |
| 89 | + view_op = exir_ops.edge.aten.view_copy.default |
| 90 | + conv_op = exir_ops.edge.aten.convolution.default |
| 91 | + slice_op = exir_ops.edge.aten.slice_copy.Tensor |
| 92 | + with graph.inserting_before(node): |
| 93 | + # Reshape to 4D with |
| 94 | + view_args = (input_node, conv_shape) |
| 95 | + view_node = create_node(graph, view_op, args=view_args, from_node=node) |
| 96 | + |
| 97 | + conv_args = ( |
| 98 | + view_node, |
| 99 | + weight_node, |
| 100 | + None, |
| 101 | + [1, 1], |
| 102 | + pad_shape, |
| 103 | + [1, 1], |
| 104 | + False, |
| 105 | + [0], |
| 106 | + 1, |
| 107 | + ) |
| 108 | + conv_node = create_node(graph, conv_op, args=conv_args, from_node=node) |
| 109 | + |
| 110 | + # The convolution is inserted after quantization, so we need to set our |
| 111 | + # own quantization parameters for the weights here. However since the |
| 112 | + # data is ones directly created as int8, they already have correct scale |
| 113 | + # and so no scaling needs to be done, i.e. set scale=1.0, zero_point=0.0 |
| 114 | + if ( |
| 115 | + "input_qparams" in conv_node.meta |
| 116 | + and len(conv_node.meta["input_qparams"]) > 0 |
| 117 | + ): |
| 118 | + qparams = QuantArgs(1.0, 0.0, -128, 127, torch.int8) |
| 119 | + conv_node.meta["input_qparams"][1] = qparams |
| 120 | + |
| 121 | + slice_args = (conv_node, 2, 0, original_shape[dim]) |
| 122 | + slice_node = create_node( |
| 123 | + graph, slice_op, args=slice_args, from_node=node |
| 124 | + ) |
| 125 | + |
| 126 | + view_original_args = (slice_node, original_shape) |
| 127 | + view_original_node = create_node( |
| 128 | + graph, view_op, args=view_original_args, from_node=node |
| 129 | + ) |
| 130 | + |
| 131 | + # Replace and remove original |
| 132 | + node.replace_all_uses_with(view_original_node) |
| 133 | + graph.erase_node(node) |
| 134 | + modified = True |
| 135 | + |
| 136 | + if modified: |
| 137 | + # Cleanup |
| 138 | + graph.eliminate_dead_code() |
| 139 | + graph_module.recompile() |
| 140 | + # Apply any operator-level transforms |
| 141 | + graph_module = super().call(graph_module).graph_module |
| 142 | + return PassResult(graph_module, modified) |
0 commit comments