diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py
index 6ed578c27ea..02b0ba8e386 100644
--- a/backends/arm/_passes/__init__.py
+++ b/backends/arm/_passes/__init__.py
@@ -51,7 +51,6 @@
 from .decompose_sqrt_pass import DecomposeSqrtPass  # noqa
 from .decompose_sum_pass import DecomposeSumPass  # noqa
 from .decompose_var_pass import DecomposeVarPass  # noqa
-from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass  # noqa
 from .fold_qdq_with_annotated_qparams_pass import (  # noqa
     FoldAndAnnotateQParamsPass,
     QuantizeOperatorArguments,
diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 3a0e75fad05..6570891509f 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -56,7 +56,6 @@
     DecomposeSqrtPass,
     DecomposeSumPass,
     DecomposeVarPass,
-    DecorateFp32toInt32CastingPass,
     FoldAndAnnotateQParamsPass,
     FuseBatchnorm2DPass,
     FuseConstantArgsPass,
@@ -201,9 +200,6 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
         self.add_pass(MatchArgRanksPass(exported_program))
         self.add_pass(DecomposeAdaptiveAvgPool2dPass())
         self.add_pass(DecomposeAvgPool2d())
-        self.add_pass(
-            DecorateFp32toInt32CastingPass()
-        )  # Require that no new fp32->int32 is introduced after this pass
         self.add_pass(ComputeConstantOpsAOT(exported_program))
         self.add_pass(DecomposeGroupedConv())
 
diff --git a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py b/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
deleted file mode 100644
index d6f7ac2ceac..00000000000
--- a/backends/arm/_passes/decorate_fp32_to_int32_casting_pass.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-# pyre-unsafe
-
-
-import torch
-from executorch.backends.arm._passes import ArmPass
-from executorch.backends.arm._passes.arm_pass_utils import get_node_arg
-from executorch.exir.dialects._ops import ops as exir_ops
-
-
-def _get_decorated_ops(op):
-    if op in DecorateFp32toInt32CastingPass.targets:
-        return (
-            exir_ops.edge.aten.full.default,
-            exir_ops.edge.aten.ge.Tensor,
-            exir_ops.edge.aten.floor.default,
-            exir_ops.edge.aten.ceil.default,
-            exir_ops.edge.aten.where.self,
-        )
-    else:
-        raise RuntimeError(f"Can't get decorated ops for op {op}")
-
-
-class DecorateFp32toInt32CastingPass(ArmPass):
-    """
-    To lower pytorch fp32 -> int32 casting to TOSA,
-    we need to transform the value with Ceil, Floor, and Where.
-    Before:
-        output = to_copy(x, dtype=torch.int32)
-    After:
-        %zero = full((1,), 0.0, dtype=torch.float32)
-        is_non_negative = x >= %zero
-        floor_x = floor(x)
-        ceil_x = ceil(x)
-        decorated_x = where(is_non_negative, floor_x, ceil_x)
-        output = to_copy(decorated_x, dtype=torch.int32)
-    """
-
-    targets = [
-        exir_ops.edge.aten._to_copy.default,
-        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
-    ]
-
-    def call_operator(self, op, args, kwargs, meta):
-        if op not in self.targets:
-            return super().call_operator(op, args, kwargs, meta)
-
-        input = get_node_arg(args, 0)
-        input_dtype = input.node.meta["val"].dtype
-        output_dtype = meta["val"].dtype
-
-        if not (input_dtype == torch.float32 and output_dtype == torch.int32):
-            return super().call_operator(op, args, kwargs, meta)
-
-        op_full, op_ge, op_floor, op_ceil, op_where = _get_decorated_ops(op)
-
-        zero = super().call_operator(
-            op_full,
-            args=((1,) * len(meta["val"].size()), 0.0),
-            kwargs={"dtype": torch.float32},
-            meta=meta,
-            updated=True,
-        )
-
-        is_non_negative = super().call_operator(
-            op_ge, (input, zero), {}, meta, updated=True
-        )
-        floor_x = super().call_operator(op_floor, (input,), {}, meta, updated=True)
-        ceil_x = super().call_operator(op_ceil, (input,), {}, meta, updated=True)
-        decorated_x = super().call_operator(
-            op_where, (is_non_negative, floor_x, ceil_x), {}, meta, updated=True
-        )
-
-        return super().call_operator(op, (decorated_x,), kwargs, meta, updated=True)
diff --git a/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py b/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py
deleted file mode 100644
index 25312b89748..00000000000
--- a/backends/arm/test/passes/test_decorate_fp32_to_int32_casting_pass.py
+++ /dev/null
@@ -1,80 +0,0 @@
-# Copyright 2025 Arm Limited and/or its affiliates.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Tuple
-
-import torch
-from executorch.backends.arm.test import common
-
-from executorch.backends.arm.test.tester.test_pipeline import (
-    OpNotSupportedPipeline,
-    TosaPipelineMI,
-)
-
-input_t1 = Tuple[torch.Tensor]  # Input x
-
-
-class FP32ToINT32Casting(torch.nn.Module):
-    def __init__(self, target_dtype):
-        super().__init__()
-        self.target_dtype = target_dtype
-
-    def forward(self, x: torch.Tensor):
-        return x.to(self.target_dtype)
-
-
-test_data_fp32_input = {
-    "fp32_input_rank1": lambda: (
-        torch.rand((4), dtype=torch.float32),
-        torch.int32,
-    ),
-    "fp32_input_rank2": lambda: (
-        torch.rand((3, 4), dtype=torch.float32),
-        torch.int32,
-    ),
-    "fp32_input_rank3": lambda: (
-        torch.rand((2, 3, 4), dtype=torch.float32),
-        torch.int32,
-    ),
-    "fp32_input_rank4": lambda: (
-        torch.rand((1, 2, 3, 4), dtype=torch.float32),
-        torch.int32,
-    ),
-}
-
-
-@common.parametrize("test_data", test_data_fp32_input)
-def test_decorate_fp32_to_int32_casting_tosa_MI(test_data: Tuple):
-    test_tensor, target_dtype = test_data()
-    module = FP32ToINT32Casting(target_dtype)
-
-    pipeline = TosaPipelineMI[input_t1](
-        module,
-        (test_tensor,),
-        aten_op=[],
-        exir_op=[],
-    )
-    pipeline.run()
-
-
-@common.parametrize("test_data", test_data_fp32_input)
-def test_decorate_fp32_to_int32_casting_tosa_BI(test_data: Tuple):
-    """
-    Casting operations involving floating-point dtypes are rejected in the BI/INT profile,
-    so the DecorateFp32toInt32CastingPass is not required there.
-    This BI test ensures that such casts are rejected as expected.
-    """
-    test_tensor, target_dtype = test_data()
-    module = FP32ToINT32Casting(target_dtype)
-
-    pipeline = OpNotSupportedPipeline[input_t1](
-        module,
-        (test_tensor,),
-        {
-            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
-        },
-        quantize=True,
-    )
-    pipeline.run()
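Reviewer context (not part of the diff): below is a minimal standalone sketch, using plain PyTorch only, of the rewrite the removed DecorateFp32toInt32CastingPass performed. Per its docstring, the pass decorated an fp32 -> int32 cast with where/floor/ceil so that truncation toward zero is explicit in the graph before lowering. `decorated_cast` is a hypothetical name used here for illustration, not an API from this codebase.

    import torch

    def decorated_cast(x: torch.Tensor) -> torch.Tensor:
        # Mirror the pass's rewrite: floor(x) for non-negative values,
        # ceil(x) for negative ones, i.e. round toward zero, then cast.
        zero = torch.zeros_like(x)
        decorated = torch.where(x >= zero, torch.floor(x), torch.ceil(x))
        return decorated.to(torch.int32)

    x = torch.tensor([-1.7, -0.2, 0.0, 0.2, 1.7])
    # PyTorch's direct fp32 -> int32 cast also truncates toward zero,
    # so the decorated path and the plain cast agree elementwise.
    assert torch.equal(decorated_cast(x), x.to(torch.int32))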