| 
 | 1 | +# Copyright 2025 Arm Limited and/or its affiliates.  | 
 | 2 | +#  | 
 | 3 | +# This source code is licensed under the BSD-style license found in the  | 
 | 4 | +# LICENSE file in the root directory of this source tree.  | 
 | 5 | + | 
 | 6 | +# pyre-unsafe  | 
 | 7 | + | 
 | 8 | + | 
 | 9 | +import logging  | 
 | 10 | +from math import prod  | 
 | 11 | + | 
 | 12 | +import torch  | 
 | 13 | +from executorch.exir.dialects._ops import ops as exir_ops  | 
 | 14 | +from executorch.exir.pass_base import ExportPass, PassResult  | 
 | 15 | + | 
 | 16 | +from .arm_pass_utils import create_node, get_first_fake_tensor  | 
 | 17 | + | 
# Module-level logger for this pass. Its threshold is set to WARNING, so
# records below that severity emitted through this logger are discarded.
logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
 | 20 | + | 
 | 21 | + | 
 | 22 | +class DecomposeEmbeddingPass(ExportPass):  | 
 | 23 | +    """  | 
 | 24 | +    This pass decomposes embedding into index_select.  | 
 | 25 | +
  | 
 | 26 | +    Example:  | 
 | 27 | +          o = embedding(w, i)  | 
 | 28 | +    Becomes:  | 
 | 29 | +          i = view_copy(i)  # flatten indices  | 
 | 30 | +          o = index_select(w, i)  | 
 | 31 | +          o = view_copy(o)  # reshape back output  | 
 | 32 | +    Note:  | 
 | 33 | +         i = indices is expected to be int32 before this pass  | 
 | 34 | +    """  | 
 | 35 | + | 
 | 36 | +    aten_ops = (torch.ops.aten.embedding.default,)  | 
 | 37 | +    edge_ops = (exir_ops.edge.aten.embedding.default,)  | 
 | 38 | + | 
 | 39 | +    def get_decomposition(self, op):  | 
 | 40 | +        if op in self.aten_ops:  | 
 | 41 | +            return (  | 
 | 42 | +                torch.ops.aten.view_copy.default,  | 
 | 43 | +                torch.ops.aten.index_select.default,  | 
 | 44 | +            )  | 
 | 45 | + | 
 | 46 | +        if op in self.edge_ops:  | 
 | 47 | +            return (  | 
 | 48 | +                exir_ops.edge.aten.view_copy.default,  | 
 | 49 | +                exir_ops.edge.aten.index_select.default,  | 
 | 50 | +            )  | 
 | 51 | +        raise RuntimeError(  | 
 | 52 | +            f"[{self.__class__.__name__}] Can't get decomposition for op {op}"  | 
 | 53 | +        )  | 
 | 54 | + | 
 | 55 | +    def call(self, graph_module):  | 
 | 56 | +        graph = graph_module.graph  | 
 | 57 | +        modified_graph = False  | 
 | 58 | + | 
 | 59 | +        for node in graph.nodes:  | 
 | 60 | +            if node.op != "call_function":  | 
 | 61 | +                continue  | 
 | 62 | +            if node.target not in self.aten_ops + self.edge_ops:  | 
 | 63 | +                continue  | 
 | 64 | + | 
 | 65 | +            args = node.args  | 
 | 66 | + | 
 | 67 | +            weights = args[0]  | 
 | 68 | +            indices = args[1]  | 
 | 69 | + | 
 | 70 | +            weights_shape = get_first_fake_tensor(weights).shape  | 
 | 71 | +            indices_shape = get_first_fake_tensor(indices).shape  | 
 | 72 | + | 
 | 73 | +            output_shape = torch.Size(list(indices_shape) + [weights_shape[1]])  | 
 | 74 | +            if output_shape != get_first_fake_tensor(node).shape:  | 
 | 75 | +                raise RuntimeError(  | 
 | 76 | +                    f"[{self.__class__.__name__}] Unexpected output shape mismatch {output_shape} "  | 
 | 77 | +                    "!= {get_first_fake_tensor(node).shape}"  | 
 | 78 | +                )  | 
 | 79 | + | 
 | 80 | +            view_copy_op, index_select_op = self.get_decomposition(node.target)  | 
 | 81 | + | 
 | 82 | +            with graph.inserting_before(node):  | 
 | 83 | +                reshaped_indices = [prod(list(indices_shape))]  | 
 | 84 | +                flattened_indices = create_node(  | 
 | 85 | +                    graph=graph,  | 
 | 86 | +                    op_target=view_copy_op,  | 
 | 87 | +                    args=(indices, reshaped_indices),  | 
 | 88 | +                )  | 
 | 89 | +                node.replace_input_with(indices, flattened_indices)  | 
 | 90 | + | 
 | 91 | +                index_select = create_node(  | 
 | 92 | +                    graph=graph,  | 
 | 93 | +                    op_target=index_select_op,  | 
 | 94 | +                    args=(weights, 0, flattened_indices),  | 
 | 95 | +                )  | 
 | 96 | +                node.replace_all_uses_with(index_select)  | 
 | 97 | +                graph.erase_node(node)  | 
 | 98 | + | 
 | 99 | +            with graph.inserting_after(index_select):  | 
 | 100 | +                restored_output = create_node(  | 
 | 101 | +                    graph,  | 
 | 102 | +                    view_copy_op,  | 
 | 103 | +                )  | 
 | 104 | +                restored_output.args = (  | 
 | 105 | +                    index_select,  | 
 | 106 | +                    output_shape,  | 
 | 107 | +                )  | 
 | 108 | +                original_users = [  | 
 | 109 | +                    user for user in index_select.users if user != restored_output  | 
 | 110 | +                ]  | 
 | 111 | +                for user in original_users:  | 
 | 112 | +                    user.replace_input_with(index_select, restored_output)  | 
 | 113 | + | 
 | 114 | +            modified_graph = True  | 
 | 115 | + | 
 | 116 | +        if modified_graph:  | 
 | 117 | +            graph.eliminate_dead_code()  | 
 | 118 | +            graph_module.recompile()  | 
 | 119 | +            graph_module = super().call(graph_module).graph_module  | 
 | 120 | +        return PassResult(graph_module, modified_graph)  | 
0 commit comments