# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import EdgeOpOverload, ExportPass, PassResult
from torch._subclasses.fake_tensor import FakeTensor


logger = logging.getLogger(__name__)


class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass):
    """
    Insert an int64->int32 cast after each int64 placeholder.

    Note: Overflow checks are not applied in this pass. It is the user's
    responsibility to ensure that values fit within the int32 range.
    """

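    # A minimal sketch of the rewrite this pass performs, assuming a single int64
    # placeholder "x" (dtype shown after the colon):
    #
    #   before:  x: i64 ------------------------------> op(x)
    #   after:   x: i64 --_to_dim_order_copy(i32)--> x_i32 --> op(x_i32)
    #
    # Call sites listed in I64_INPUT_ARG_POSITIONS below additionally receive a
    # local i32 -> i64 cast, so operators that insist on int64 inputs still pass
    # PyTorch's input validation.
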
    # Ops that require i64 inputs → positions of args to upcast.
    # Key: op overload; Value: zero-based indices of positional args that must be i64.
    I64_INPUT_ARG_POSITIONS = {
        torch.ops.aten.one_hot.default: (0,),
    }
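
    # Hypothetical example of how another entry would look (not currently handled):
    # aten.gather.default requires an int64 index tensor at positional index 2, so it
    # would be listed as `torch.ops.aten.gather.default: (2,)`.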

    def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
        """
        If an operator requires int64 inputs but dtype propagation (via call_operator)
        produced int32, insert a local int32→int64 cast at the call site to satisfy
        PyTorch's operator input validation.
        """
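        # Illustrative effect for aten.one_hot.default (indices at position 0), assuming
        # the placeholder-cast step has already downcast the indices to int32:
        #
        #   before:  indices_i32 ------------------------------> one_hot(indices_i32)
        #   after:   indices_i32 --_to_dim_order_copy(i64)--> one_hot(cast)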
        modified = False
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if node.target not in self.I64_INPUT_ARG_POSITIONS:
                continue

            with graph.inserting_before(node):
                arg_positions = self.I64_INPUT_ARG_POSITIONS.get(node.target)
                args_list = list(node.args)
                for pos in arg_positions:  # type: ignore[union-attr]
                    input_arg = args_list[pos]
                    to_copy_op = self._get_decomposition(graph)
                    cast_node = graph_module.graph.create_node(
                        "call_function",
                        to_copy_op,
                        (input_arg,),
                        {"dtype": torch.int64},
                    )
                    # The cast produces the input tensor recast to int64, so mirror the
                    # input's fake tensor rather than the consumer's output value.
                    cast_node.meta["val"] = input_arg.meta["val"].to(torch.int64)
                    args_list[pos] = cast_node
                node.args = tuple(args_list)
                modified = True
        return modified

    def _graph_uses_edge_ops(self, graph: torch.fx.Graph) -> bool:
        for n in graph.nodes:
            if n.op == "call_function" and isinstance(n.target, EdgeOpOverload):
                return True
        return False

    def _get_decomposition(self, graph: torch.fx.Graph):
        """Return the _to_dim_order_copy overload for the graph's dialect (edge vs. aten)."""
        if self._graph_uses_edge_ops(graph):
            return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
        else:
            return torch.ops.dim_order_ops._to_dim_order_copy.default

    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype

    def _insert_placeholder_i64_to_i32_casts(self, graph_module: torch.fx.GraphModule):
        modified = False
        graph = graph_module.graph
        for node in graph.nodes:
            if node.op != "placeholder":
                continue
            node_val = node.meta["val"]
            if not self._is_tensor_of_dtype(node_val, torch.int64):
                continue

            to_copy_op = self._get_decomposition(graph)
            with graph.inserting_after(node):
                cast_after = create_node(
                    graph,
                    to_copy_op,
                    args=(node,),
                    kwargs={
                        "dtype": torch.int32,
                    },
                )
                users = [user for user in node.users if user != cast_after]
                for user in users:
                    user.replace_input_with(node, cast_after)
                logger.warning(
                    f"Inserting cast node {cast_after.name} after int64 placeholder {node.name}"
                    f" to cast it to int32; placeholder defined in"
                    f" {node.meta.get('stack_trace', '[no stack trace found]')}"
                )
                modified = True
        return modified

    def call(self, graph_module: torch.fx.GraphModule):
        modified = False
        modified |= self._insert_placeholder_i64_to_i32_casts(graph_module)
        modified |= self._insert_callsite_i32_to_i64_casts(graph_module)

        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)
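

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; not executed as part of this module). The
# demo module and export flow below are hypothetical; in practice the pass
# runs inside the Arm backend's pass pipeline:
#
#   class Demo(torch.nn.Module):
#       def forward(self, indices):
#           return torch.nn.functional.one_hot(indices, num_classes=4)
#
#   ep = torch.export.export(Demo(), (torch.arange(4, dtype=torch.int64),))
#   result = InsertInt32CastsAfterInt64PlaceholdersPass()(ep.graph_module)
#   # result.modified is True: the int64 placeholder now feeds an int32 cast,
#   # and one_hot's indices argument is cast back to int64 at its call site.
# ---------------------------------------------------------------------------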