|  | 
|  | 1 | +# Copyright (c) Qualcomm Innovation Center, Inc. | 
|  | 2 | +# All rights reserved | 
|  | 3 | +# | 
|  | 4 | +# This source code is licensed under the BSD-style license found in the | 
|  | 5 | +# LICENSE file in the root directory of this source tree. | 
|  | 6 | +import torch | 
|  | 7 | + | 
|  | 8 | +from executorch.exir.pass_base import ExportPass, PassResult | 
|  | 9 | + | 
|  | 10 | +from .utils import copy_nn_module_stack | 
|  | 11 | + | 
|  | 12 | + | 
|  | 13 | +class SliceCopy(torch.nn.Module): | 
|  | 14 | +    def __init__(self, val_shape, shifts, dims): | 
|  | 15 | +        super().__init__() | 
|  | 16 | +        self.val_shape = val_shape | 
|  | 17 | +        if dims[0] is None: | 
|  | 18 | +            self.shifts = [shifts[0] % torch.numel(torch.tensor(val_shape))] | 
|  | 19 | +        else: | 
|  | 20 | +            self.shifts = [shift % val_shape[dim] for shift, dim in zip(shifts, dims)] | 
|  | 21 | +        self.dims = dims | 
|  | 22 | + | 
|  | 23 | +    def forward(self, x): | 
|  | 24 | +        if self.dims[0] is None: | 
|  | 25 | +            y = x.flatten() | 
|  | 26 | +            y = torch.cat((y[-self.shifts[0] :], y[: -self.shifts[0]])) | 
|  | 27 | +            return y.view(self.val_shape) | 
|  | 28 | + | 
|  | 29 | +        for shift, dim in zip(self.shifts, self.dims): | 
|  | 30 | +            x = torch.cat( | 
|  | 31 | +                ( | 
|  | 32 | +                    x[(slice(None),) * dim + (slice(-shift, None),)], | 
|  | 33 | +                    x[(slice(None),) * dim + (slice(0, -shift),)], | 
|  | 34 | +                ), | 
|  | 35 | +                dim=dim, | 
|  | 36 | +            ) | 
|  | 37 | +        return x | 
|  | 38 | + | 
|  | 39 | + | 
|  | 40 | +class DecomposeRoll(ExportPass): | 
|  | 41 | +    """ | 
|  | 42 | +    Decompose roll into slice and cat. | 
|  | 43 | +    """ | 
|  | 44 | + | 
|  | 45 | +    def __init__(self) -> None: | 
|  | 46 | +        super().__init__() | 
|  | 47 | + | 
|  | 48 | +    def call(self, graph_module: torch.fx.GraphModule) -> PassResult: | 
|  | 49 | +        graph = graph_module.graph | 
|  | 50 | +        for node in graph.nodes: | 
|  | 51 | +            if "roll" in str(node.target): | 
|  | 52 | +                input_node, shifts = node.args[0], node.args[1] | 
|  | 53 | +                dims = node.args[2] if len(node.args) == 3 else None | 
|  | 54 | + | 
|  | 55 | +                # Normalize shifts and dims to lists | 
|  | 56 | +                shifts = shifts if isinstance(shifts, (list, tuple)) else [shifts] | 
|  | 57 | +                dims = dims if isinstance(dims, (list, tuple)) else [dims] | 
|  | 58 | + | 
|  | 59 | +                model = SliceCopy(input_node.meta["val"].shape, shifts, dims) | 
|  | 60 | +                decomposed_module = torch.export.export( | 
|  | 61 | +                    model, (input_node.meta["val"],), strict=True | 
|  | 62 | +                ).module() | 
|  | 63 | + | 
|  | 64 | +                with graph.inserting_before(node): | 
|  | 65 | +                    # remap is used to map original node values to new node values, | 
|  | 66 | +                    # which ensures that reference to nodes are correctly updated in the new graph | 
|  | 67 | +                    remap = {"x": input_node} | 
|  | 68 | + | 
|  | 69 | +                    for decomposed_node in decomposed_module.graph.nodes: | 
|  | 70 | +                        copy_nn_module_stack(node, decomposed_node) | 
|  | 71 | +                        # no need to copy existent 'output' | 
|  | 72 | +                        if decomposed_node.op == "output": | 
|  | 73 | +                            for user in node.users.copy(): | 
|  | 74 | +                                # remap | 
|  | 75 | +                                user.replace_input_with( | 
|  | 76 | +                                    node, | 
|  | 77 | +                                    remap[decomposed_node.args[0][0]], | 
|  | 78 | +                                ) | 
|  | 79 | +                        # no need to copy existent placeholders | 
|  | 80 | +                        elif decomposed_node.op == "placeholder": | 
|  | 81 | +                            # replace node map from string to graph node | 
|  | 82 | +                            remap[decomposed_node] = remap.pop(decomposed_node.name) | 
|  | 83 | +                        else: | 
|  | 84 | +                            remap[decomposed_node] = graph.node_copy( | 
|  | 85 | +                                decomposed_node, | 
|  | 86 | +                                arg_transform=lambda x, remap=remap: remap[x], | 
|  | 87 | +                            ) | 
|  | 88 | + | 
|  | 89 | +                    graph.erase_node(node) | 
|  | 90 | + | 
|  | 91 | +        graph.eliminate_dead_code() | 
|  | 92 | +        graph_module.recompile() | 
|  | 93 | +        return PassResult(graph_module, True) | 
0 commit comments