diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py
index 640bbf3f9be..210f006d72d 100644
--- a/backends/arm/_passes/arm_pass_manager.py
+++ b/backends/arm/_passes/arm_pass_manager.py
@@ -59,6 +59,9 @@
 
 )
 from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
+from executorch.backends.transforms.decompose_sdpa import (
+    DecomposeScaledDotProductAttention,
+)
 from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
@@ -194,6 +197,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         )
 
     def transform_for_annotation_pipeline(self, graph_module: GraphModule):
+        self.add_pass(DecomposeScaledDotProductAttention())
         self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
         self.add_pass(ScalarsToAttributePass())
         self.add_pass(DecomposeLayerNormPass())
diff --git a/backends/arm/_passes/decompose_softmax_pass.py b/backends/arm/_passes/decompose_softmax_pass.py
index 7e8591eb386..a735501f711 100644
--- a/backends/arm/_passes/decompose_softmax_pass.py
+++ b/backends/arm/_passes/decompose_softmax_pass.py
@@ -8,7 +8,11 @@
 from executorch.exir.pass_base import ExportPass
 
 # For BI case
-torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
+torch_softmax = (
+    torch.ops.aten.softmax.int,
+    torch.ops.aten._safe_softmax.default,
+    torch.ops.aten.log_softmax.int,
+)
 # For MI case
 edge_softmax = (
     exir_ops.edge.aten._softmax.default,
diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py
index fb0d5eb75d3..3aa953bf602 100644
--- a/backends/arm/test/models/test_conformer.py
+++ b/backends/arm/test/models/test_conformer.py
@@ -83,7 +83,6 @@ def test_conformer_tosa_BI(self):
             )
         )
 
-    @unittest.expectedFailure  # TODO(MLETORCH-635)
     def test_conformer_u55_BI(self):
         tester = (
ArmTester( @@ -97,13 +96,20 @@ def test_conformer_u55_BI(self): .to_executorch() .serialize() ) + if conftest.is_option_enabled("corstone_fvp"): - tester.run_method_and_compare_outputs( - qtol=1.0, - rtol=1.0, - atol=5.0, - inputs=get_test_inputs(self.dim, self.lengths, self.num_examples), - ) + try: + tester.run_method_and_compare_outputs( + qtol=1.0, + rtol=1.0, + atol=5.0, + inputs=get_test_inputs(self.dim, self.lengths, self.num_examples), + ) + self.fail( + "TODO(MLETORCH-635): Expected failure under FVP option, but test passed." + ) + except Exception: + pass @unittest.expectedFailure # TODO(MLETORCH-635) def test_conformer_u85_BI(self): diff --git a/backends/arm/test/ops/test_sdpa.py b/backends/arm/test/ops/test_sdpa.py new file mode 100644 index 00000000000..470030f67fd --- /dev/null +++ b/backends/arm/test/ops/test_sdpa.py @@ -0,0 +1,45 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+
+from typing import Tuple
+
+import torch
+
+from executorch.backends.arm.test.tester.test_pipeline import (
+    TosaPipelineBI,
+    TosaPipelineMI,
+)
+
+
+class SDPA(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, query, key, value):
+        return torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
+        )
+
+
+input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+
+
+def test_sdpa_MI():
+    test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
+    pipeline = TosaPipelineMI[input_t](SDPA(), test_input, [], [])
+    pipeline.pop_stage("check_count.exir")
+    pipeline.run()
+
+
+def test_sdpa_BI():
+    test_input = tuple(torch.randn(1, 3, 197, 64) for x in range(3))
+    pipeline = TosaPipelineBI[input_t](SDPA(), test_input, [], [])
+    pipeline.pop_stage("check.quant_nodes")
+    pipeline.pop_stage("check_count.exir")
+    pipeline.pop_stage(
+        "run_method_and_compare_outputs"
+    )  # TODO: reference is not quantized
+    pipeline.run()