# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from executorch.backends.qualcomm.builders.utils import is_graph_output
from executorch.backends.qualcomm.utils.constants import QCOM_ORIG_DTYPE
from executorch.exir import ExirExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.program._program import _get_updated_graph_signature
from torch._subclasses.fake_tensor import FakeTensor

class TensorI64toI32(ExportPass):
    """
    Insert a cast node to cast dtype from int64 to int32.
    This will only be applied on fake tensors.

    The pass is applied twice, once per program representation (see ``call``):
      Stage 1 (ExirExportedProgram, before to_edge): append an
      ``aten._to_copy`` after each op in ``cast_ops`` whose fake-tensor
      output is int64, and tag every graph output with its original dtype
      under ``QCOM_ORIG_DTYPE``.
      Stage 2 (torch.export ExportedProgram, after to_edge): for graph
      outputs whose dtype no longer matches the recorded original, insert a
      cast back so the program's outputs agree with the source nn.Module.
    """

    # Ops whose int64 fake-tensor outputs should be narrowed to int32.
    cast_ops = {
        torch.ops.aten.argmin.default,
    }

    def __init__(self, edge_program):
        # ``edge_program`` may be either an ExirExportedProgram (stage 1)
        # or a torch.export ExportedProgram (stage 2); ``call`` dispatches
        # on its type.
        super().__init__()
        self.edge_program = edge_program

    # pyre-ignore[2]
    def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
        """Return True if ``node_val`` is a FakeTensor of the given dtype."""
        return isinstance(node_val, FakeTensor) and node_val.dtype == dtype

    def _cast_to_int32(self, core_ep: ExirExportedProgram):
        """Stage 1: narrow int64 outputs of ``cast_ops`` nodes to int32.

        Also records each graph output's original dtype (or tuple of
        dtypes) in ``n.meta[QCOM_ORIG_DTYPE]`` so stage 2 can restore it.
        Updates the graph signature and re-validates the program afterwards.
        """
        copy_op = torch.ops.aten._to_copy.default
        for n in core_ep.exported_program.graph.nodes:
            # Keep track of original output dtype so we ensure the dtype of the graph is consistent with nn.Module
            if is_graph_output(n):
                if isinstance(n.meta["val"], tuple):
                    dtype_list = [tensor.dtype for tensor in n.meta["val"]]
                    n.meta[QCOM_ORIG_DTYPE] = dtype_list
                else:
                    n.meta[QCOM_ORIG_DTYPE] = n.meta["val"].dtype
                continue
            if n.target in self.cast_ops:
                node_val = n.meta["val"]
                if self._is_tensor_of_dtype(node_val, torch.int64):
                    with core_ep.exported_program.graph.inserting_after(n):
                        # Snapshot users BEFORE creating the cast node, so the
                        # cast node (itself a new user of ``n``) is not rewired
                        # onto itself below.
                        users = list(n.users.keys())
                        args = (n,)
                        cast_node = core_ep.exported_program.graph.create_node(
                            "call_function",
                            copy_op,
                            args,
                            {"dtype": torch.int32},
                        )
                        cast_node.meta["val"] = node_val.to(torch.int32)
                        # NOTE: the original redundantly re-assigned
                        # ``cast_node.args = args`` here; ``create_node``
                        # already received ``args``, so it was removed.

                        for user in users:
                            user.replace_input_with(n, cast_node)

        # The graph changed, so the signature must be recomputed before
        # validation.
        core_ep.exported_program._graph_signature = _get_updated_graph_signature(
            core_ep.exported_program._graph_signature,
            core_ep.exported_program.graph_module,
        )
        core_ep.exported_program._validate()

    def _preserve_output_dtype(
        self, exported_program: torch.export.exported_program.ExportedProgram
    ):
        """Stage 2: cast graph outputs back to the dtype tagged in stage 1.

        Only nodes carrying ``QCOM_ORIG_DTYPE`` whose current dtype differs
        from the recorded one get a cast; only their uses in the ``output``
        node are rewired, so intermediate consumers keep the int32 value.

        Raises:
            AssertionError: for multi-output (tuple) nodes whose dtype
                changed — casting those back is not supported yet.
        """
        graph_module = exported_program.graph_module
        copy_op = exir_ops.edge.aten._to_copy.default
        for n in graph_module.graph.nodes:
            if is_graph_output(n) and QCOM_ORIG_DTYPE in n.meta:
                if isinstance(n.meta["val"], tuple):
                    for i, dtype in enumerate(n.meta[QCOM_ORIG_DTYPE]):
                        # TODO: Enable this in future to support OP such as topK
                        if n.meta["val"][i].dtype != dtype:
                            raise AssertionError(
                                "Multi output nodes currently don't support casting dtype back."
                            )
                elif n.meta["val"].dtype != n.meta[QCOM_ORIG_DTYPE]:
                    if n.meta[QCOM_ORIG_DTYPE] != torch.int64:
                        logging.warning(
                            "This pass is intended to maintain output as int64 when nn.Module outputs int64. Other dtype modification is detected. Please ensure this is desired."
                        )
                    with graph_module.graph.inserting_after(n):
                        orig_dtype = n.meta[QCOM_ORIG_DTYPE]
                        node_val = n.meta["val"]
                        args = (n,)
                        users = list(n.users.keys())
                        # Rewire only the ``output`` node; other users keep
                        # consuming the un-cast value.
                        output_users = [
                            user for user in users if user.target == "output"
                        ]
                        cast_node = graph_module.graph.create_node(
                            "call_function",
                            copy_op,
                            args,
                            {"dtype": orig_dtype},
                        )
                        cast_node.meta["val"] = node_val.to(orig_dtype)
                        # (Redundant ``cast_node.args = args`` removed — see
                        # the same cleanup in ``_cast_to_int32``.)
                        for user in output_users:
                            user.replace_input_with(n, cast_node)

    def call(self, graph_module: torch.fx.GraphModule):
        """Dispatch to the correct stage based on ``self.edge_program``'s type.

        Returns:
            PassResult: the (possibly mutated) ``graph_module`` with
            ``modified=True``.

        Raises:
            AssertionError: if ``self.edge_program`` is neither program type.
        """
        # Stage 1: _cast_to_int32
        # We add to_copy after the desired operations during this stage because the data type only propagates before to_edge.
        # If we don't add to_copy here but do it after to_edge, the next operation after to_copy() will still expect int64 as its output.
        # Stage 2: _preserve_output_dtype
        # We will tag the output dtype during stage 1, and we will ensure that if user expects int64 as output,
        # we need to convert the output back to int64 if it is casted from int64->int32 during stage 1.
        if isinstance(self.edge_program, ExirExportedProgram):
            self._cast_to_int32(self.edge_program)
            self.edge_program.exported_program.graph_module.recompile()
        elif isinstance(
            self.edge_program, torch.export.exported_program.ExportedProgram
        ):
            self._preserve_output_dtype(self.edge_program)
        else:
            raise AssertionError(
                "Should be ExirExportedProgram at stage 1 and torch.export.exported_program.ExportedProgram at stage 2"
            )
        return PassResult(graph_module, True)