diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index 9f966596de3..3554fc0954c 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -49,6 +49,7 @@ def __init__(self, exported_program): exir_ops.edge.aten.bitwise_left_shift.Tensor, exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.ge.Tensor, exir_ops.edge.aten.lt.Tensor, exir_ops.edge.aten.pow.Tensor_Tensor, exir_ops.edge.aten.where.self, diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index f24d53215f5..fed72e664f5 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -27,6 +27,7 @@ exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor, exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.gt.Scalar: exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.ge.Scalar: exir_ops.edge.aten.ge.Tensor, exir_ops.edge.aten.lt.Scalar: exir_ops.edge.aten.lt.Tensor, torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor, torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor, @@ -36,6 +37,7 @@ torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor, torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor, torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor, + torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor, torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor, } diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index b716acd0867..7276e8efffe 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -134,6 +134,7 @@ class EthosU55NotSupported(OperatorSupportBase): exir_ops.edge.aten.eq.Tensor, exir_ops.edge.aten.eq.Scalar, exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.ge.Scalar, exir_ops.edge.aten.gt.Tensor, exir_ops.edge.aten.gt.Scalar, exir_ops.edge.aten.le.Tensor, diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 1cec35923de..cd7a550b436 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -176,6 +176,7 @@ def is_node_supported( exir_ops.edge.aten.full.default, exir_ops.edge.aten.full_like.default, exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.ge.Scalar, exir_ops.edge.aten.gt.Tensor, exir_ops.edge.aten.gt.Scalar, exir_ops.edge.aten.le.Tensor, diff --git a/backends/arm/test/ops/test_ge.py b/backends/arm/test/ops/test_ge.py index a6193f6ea08..7bcd2c923a4 100644 --- a/backends/arm/test/ops/test_ge.py +++ b/backends/arm/test/ops/test_ge.py @@ -5,7 +5,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.test import common @@ -16,13 +15,14 @@ TosaPipelineMI, ) -aten_op = "torch.ops.aten.ge.Tensor" -exir_op = "executorch_exir_dialects_edge__ops_aten_ge_Tensor" - input_t = Tuple[torch.Tensor] class GreaterEqual(torch.nn.Module): + aten_op_tensor = "torch.ops.aten.ge.Tensor" + aten_op_scalar = "torch.ops.aten.ge.Scalar" + exir_op = "executorch_exir_dialects_edge__ops_aten_ge_Tensor" + def __init__(self, input, other): super().__init__() self.input_ = input @@ -31,7 +31,7 @@ def __init__(self, input, other): def forward( self, input_: torch.Tensor, - other_: torch.Tensor, + other_: torch.Tensor | int | float, ): return input_ >= other_ @@ -39,98 +39,143 @@ def get_inputs(self): return (self.input_, self.other_) -op_ge_rank1_ones = GreaterEqual( +op_ge_tensor_rank1_ones = GreaterEqual( torch.ones(5), torch.ones(5), ) -op_ge_rank2_rand = GreaterEqual( +op_ge_tensor_rank2_rand = GreaterEqual( torch.rand(4, 5), torch.rand(1, 5), ) -op_ge_rank3_randn = GreaterEqual( +op_ge_tensor_rank3_randn = GreaterEqual( torch.randn(10, 5, 2), torch.randn(10, 5, 2), ) -op_ge_rank4_randn = GreaterEqual( +op_ge_tensor_rank4_randn = GreaterEqual( torch.randn(3, 2, 2, 2), torch.randn(3, 2, 2, 2), ) -test_data_common = { - "ge_rank1_ones": op_ge_rank1_ones, - "ge_rank2_rand": op_ge_rank2_rand, - "ge_rank3_randn": op_ge_rank3_randn, - "ge_rank4_randn": op_ge_rank4_randn, +op_ge_scalar_rank1_ones = GreaterEqual(torch.ones(5), 1.0) +op_ge_scalar_rank2_rand = GreaterEqual(torch.rand(4, 5), 0.2) +op_ge_scalar_rank3_randn = GreaterEqual(torch.randn(10, 5, 2), -0.1) +op_ge_scalar_rank4_randn = GreaterEqual(torch.randn(3, 2, 2, 2), 0.3) + +test_data_tensor = { + "ge_tensor_rank1_ones": op_ge_tensor_rank1_ones, + "ge_tensor_rank2_rand": op_ge_tensor_rank2_rand, + "ge_tensor_rank3_randn": op_ge_tensor_rank3_randn, + "ge_tensor_rank4_randn": op_ge_tensor_rank4_randn, +} + +test_data_scalar = { + "ge_scalar_rank1_ones": op_ge_scalar_rank1_ones, + "ge_scalar_rank2_rand": op_ge_scalar_rank2_rand, + "ge_scalar_rank3_randn": op_ge_scalar_rank3_randn, + "ge_scalar_rank4_randn": op_ge_scalar_rank4_randn, } -@common.parametrize("test_module", test_data_common) -def test_ge_tosa_MI(test_module): +@common.parametrize("test_module", test_data_tensor) +def test_ge_tensor_tosa_MI(test_module): + pipeline = TosaPipelineMI[input_t]( + test_module, + test_module.get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_ge_scalar_tosa_MI(test_module): pipeline = TosaPipelineMI[input_t]( - test_module, test_module.get_inputs(), aten_op, exir_op + test_module, + test_module.get_inputs(), + GreaterEqual.aten_op_scalar, + GreaterEqual.exir_op, ) pipeline.run() -@common.parametrize("test_module", test_data_common) -def test_ge_tosa_BI(test_module): +@common.parametrize("test_module", test_data_tensor) +def test_ge_tensor_tosa_BI(test_module): pipeline = TosaPipelineBI[input_t]( - test_module, test_module.get_inputs(), aten_op, exir_op + test_module, + test_module.get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, ) pipeline.run() -@common.parametrize("test_module", test_data_common) -def test_ge_u55_BI(test_module): - # GREATER_EQUAL is not supported on U55. - pipeline = OpNotSupportedPipeline[input_t]( +@common.parametrize("test_module", test_data_scalar) +def test_ge_scalar_tosa_BI(test_module): + pipeline = TosaPipelineBI[input_t]( test_module, test_module.get_inputs(), - "TOSA-0.80+BI+u55", - {exir_op: 1}, + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, ) pipeline.run() -@common.parametrize("test_module", test_data_common) -def test_ge_u85_BI(test_module): - pipeline = EthosU85PipelineBI[input_t]( +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone300 +def test_ge_tensor_u55_BI(test_module): + # GREATER_EQUAL is not supported on U55. + pipeline = OpNotSupportedPipeline[input_t]( test_module, test_module.get_inputs(), - aten_op, - exir_op, - run_on_fvp=False, - use_to_edge_transform_and_lower=True, + "TOSA-0.80+BI+u55", + {GreaterEqual.exir_op: 1}, ) pipeline.run() -@common.parametrize("test_module", test_data_common) -@pytest.mark.skip(reason="The same as test_ge_u55_BI") -def test_ge_u55_BI_on_fvp(test_module): +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone300 +def test_ge_scalar_u55_BI(test_module): # GREATER_EQUAL is not supported on U55. pipeline = OpNotSupportedPipeline[input_t]( test_module, test_module.get_inputs(), "TOSA-0.80+BI+u55", - {exir_op: 1}, + {GreaterEqual.exir_op: 1}, + n_expected_delegates=1, + ) + pipeline.run() + + +@common.parametrize( + "test_module", + test_data_tensor, + xfails={"ge_tensor_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85"}, +) +@common.XfailIfNoCorstone320 +def test_ge_tensor_u85_BI(test_module): + pipeline = EthosU85PipelineBI[input_t]( + test_module, + test_module.get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + run_on_fvp=True, ) pipeline.run() @common.parametrize( "test_module", - test_data_common, - xfails={"ge_rank4_randn": "4D fails because boolean Tensors can't be subtracted"}, + test_data_scalar, + xfails={"ge_scalar_rank4_randn": "MLETORCH-847: Boolean eq result unstable on U85"}, ) -@common.SkipIfNoCorstone320 -def test_ge_u85_BI_on_fvp(test_module): +@common.XfailIfNoCorstone320 +def test_ge_scalar_u85_BI(test_module): pipeline = EthosU85PipelineBI[input_t]( test_module, test_module.get_inputs(), - aten_op, - exir_op, + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, run_on_fvp=True, - use_to_edge_transform_and_lower=True, ) pipeline.run()