Skip to content

Commit aedd502

Browse files
Merge branch 'main' into matmul_single_input
2 parents 9a2a9e7 + e88aafc commit aedd502

File tree

29 files changed

+579
-238
lines changed

29 files changed

+579
-238
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@
5959
)
6060

6161
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
62+
from executorch.backends.transforms.decompose_sdpa import (
63+
DecomposeScaledDotProductAttention,
64+
)
6265
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
6366
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
6467
from executorch.exir import ExportedProgram
@@ -194,6 +197,7 @@ def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
194197
)
195198

196199
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
200+
self.add_pass(DecomposeScaledDotProductAttention())
197201
self.add_pass(ReplaceScalarWithTensorArgPassTOSABI())
198202
self.add_pass(ScalarsToAttributePass())
199203
self.add_pass(DecomposeLayerNormPass())

backends/arm/_passes/decompose_softmax_pass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88
from executorch.exir.pass_base import ExportPass
99

1010
# For BI case
11-
torch_softmax = (torch.ops.aten.softmax.int, torch.ops.aten.log_softmax.int)
11+
torch_softmax = (
12+
torch.ops.aten.softmax.int,
13+
torch.ops.aten._safe_softmax.default,
14+
torch.ops.aten.log_softmax.int,
15+
)
1216
# For MI case
1317
edge_softmax = (
1418
exir_ops.edge.aten._softmax.default,

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ def is_node_supported(
194194
exir_ops.edge.aten.mul.Tensor,
195195
exir_ops.edge.aten.ne.Tensor,
196196
exir_ops.edge.aten.ne.Scalar,
197+
exir_ops.edge.aten.neg.default,
197198
exir_ops.edge.aten.add.Scalar,
198199
exir_ops.edge.aten.sub.Scalar,
199200
exir_ops.edge.aten.mul.Scalar,
@@ -311,6 +312,7 @@ class CheckProperQuantization(OperatorSupportBase):
311312
exir_ops.edge.aten.max_pool2d_with_indices.default,
312313
exir_ops.edge.aten.mm.default,
313314
exir_ops.edge.aten.mul.Tensor,
315+
exir_ops.edge.aten.neg.default,
314316
exir_ops.edge.aten.relu.default,
315317
exir_ops.edge.aten.sub.Tensor,
316318
exir_ops.edge.aten.upsample_bilinear2d.vec,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
op_maximum,
3232
op_minimum,
3333
op_mul,
34+
op_neg,
3435
op_permute,
3536
op_pow,
3637
op_reciprocal,

backends/arm/operators/op_neg.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
from typing import List
8+
9+
import torch.fx
10+
11+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
12+
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
13+
get_input_qparams,
14+
get_output_qparams,
15+
)
16+
from executorch.backends.arm.operators.node_visitor import (
17+
NodeVisitor,
18+
register_node_visitor,
19+
)
20+
21+
from executorch.backends.arm.tosa_mapping import TosaArg
22+
23+
24+
def get_negate_zero_points(node: torch.fx.Node, dtype: ts.DType) -> tuple[int, int]:
    """
    Look up the zero points used for the TOSA NEGATE attribute.

    The result is an ``(input1_zp, output_zp)`` pair. For INT8 the values are
    read from entry 0 of the input and output qparams of ``node``; for every
    other dtype the pair is ``(0, 0)``, since the zero points must be zero
    for non-int8 types.
    """
    # Guard clause: only INT8 carries non-trivial zero points.
    if dtype != ts.DType.INT8:
        return (0, 0)

    input_zp = get_input_qparams(node)[0].zp
    output_zp = get_output_qparams(node)[0].zp
    return (input_zp, output_zp)
35+
36+
37+
@register_node_visitor
class NegVisitor(NodeVisitor):
    """Node visitor that serializes ``aten.neg.default`` as a TOSA NEGATE operator."""

    target = "aten.neg.default"

    # Dtypes accepted by define_node; any other input dtype raises ValueError.
    supported_dtypes = {
        ts.DType.INT8,
        ts.DType.INT16,
        ts.DType.INT32,
        ts.DType.FP16,
        ts.DType.BF16,
        ts.DType.FP32,
    }

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """
        Append a NEGATE operator for ``node`` to ``tosa_graph``.

        Args:
            node: The FX node being lowered; used to look up zero points.
            tosa_graph: Serializer that receives the NEGATE operator.
            inputs: TOSA arguments of the node; only ``inputs[0]`` is used.
            output: TOSA argument describing the result tensor.

        Raises:
            ValueError: If the input dtype is not in ``supported_dtypes`` or
                does not match the output dtype.
        """
        input_dtype = inputs[0].dtype

        if input_dtype not in self.supported_dtypes:
            raise ValueError(f"Unsupported dtype for NEGATE: {input_dtype}")

        if input_dtype != output.dtype:
            # Adjacent literals are concatenated verbatim, so the separating
            # space after the full stop has to be written explicitly.
            raise ValueError(
                "All inputs and output need same dtype. "
                f"Got {inputs[0].dtype=}, {output.dtype=}"
            )

        input_zp, output_zp = get_negate_zero_points(node, input_dtype)

        attr = ts.TosaSerializerAttribute()
        attr.NegateAttribute(input1_zp=input_zp, output_zp=output_zp)
        tosa_graph.addOperator(
            ts.TosaOp.Op().NEGATE,
            [inputs[0].name],
            [output.name],
            attributes=attr,
        )

backends/arm/quantizer/quantization_annotator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,9 @@ def any_or_hardtanh_min_zero(n: Node):
375375
)
376376
]
377377
quant_properties.quant_output = _QuantProperty(0, shared_qspec) # type: ignore[arg-type]
378+
elif node.target in (torch.ops.aten.neg.default,):
379+
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
380+
quant_properties.quant_output = _QuantProperty(0, input_act_qspec)
378381
elif node.target in _one_to_one:
379382
quant_properties.quant_inputs = [_QuantProperty(0, input_act_qspec)]
380383
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)

backends/arm/test/models/test_conformer.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def test_conformer_tosa_BI(self):
8383
)
8484
)
8585

86-
@unittest.expectedFailure # TODO(MLETORCH-635)
8786
def test_conformer_u55_BI(self):
8887
tester = (
8988
ArmTester(
@@ -97,13 +96,20 @@ def test_conformer_u55_BI(self):
9796
.to_executorch()
9897
.serialize()
9998
)
99+
100100
if conftest.is_option_enabled("corstone_fvp"):
101-
tester.run_method_and_compare_outputs(
102-
qtol=1.0,
103-
rtol=1.0,
104-
atol=5.0,
105-
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
106-
)
101+
try:
102+
tester.run_method_and_compare_outputs(
103+
qtol=1.0,
104+
rtol=1.0,
105+
atol=5.0,
106+
inputs=get_test_inputs(self.dim, self.lengths, self.num_examples),
107+
)
108+
self.fail(
109+
"TODO(MLETORCH-635): Expected failure under FVP option, but test passed."
110+
)
111+
except Exception:
112+
pass
107113

108114
@unittest.expectedFailure # TODO(MLETORCH-635)
109115
def test_conformer_u85_BI(self):

backends/arm/test/ops/test_neg.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Dict, Tuple
8+
9+
import torch
10+
from executorch.backends.arm.test import common
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
EthosU55PipelineBI,
13+
EthosU85PipelineBI,
14+
TosaPipelineBI,
15+
TosaPipelineMI,
16+
)
17+
18+
input_t1 = Tuple[torch.Tensor]
19+
20+
21+
class Neg(torch.nn.Module):
    """Module under test: applies ``torch.neg`` to its single input tensor."""

    aten_op = "torch.ops.aten.neg.default"
    exir_op = "executorch_exir_dialects_edge__ops_aten_neg_default"

    # Named single-tensor inputs used to parametrize the tests in this file.
    # Entries are built in the same order as before so the random draws match.
    test_data: Dict[str, input_t1] = dict(
        rank_1_ramp=(torch.arange(-16, 16, 0.2),),
        rank_2_rand_uniform=(torch.rand(10, 10) - 0.5,),
        rank_3_all_ones=(torch.ones(10, 10, 10),),
        rank_4_all_zeros=(torch.zeros(1, 10, 10, 10),),
        rank_4_randn_pos=(torch.randn(1, 4, 4, 4) + 10,),
        rank_4_randn_neg=(torch.randn(1, 4, 4, 4) - 10,),
    )

    def forward(self, x: torch.Tensor):
        """Return the element-wise negation of ``x``."""
        negated = torch.neg(x)
        return negated
37+
38+
39+
@common.parametrize("test_data", Neg.test_data)
def test_neg_tosa_MI(test_data: input_t1):
    """Run ``Neg`` through ``TosaPipelineMI`` for one parametrized input."""
    TosaPipelineMI[input_t1](
        Neg(),
        test_data,
        Neg.aten_op,
        Neg.exir_op,
    ).run()
43+
44+
45+
@common.parametrize("test_data", Neg.test_data)
def test_neg_tosa_BI(test_data: input_t1):
    """Run ``Neg`` through ``TosaPipelineBI`` for one parametrized input."""
    TosaPipelineBI[input_t1](
        Neg(),
        test_data,
        Neg.aten_op,
        Neg.exir_op,
    ).run()
49+
50+
51+
@common.parametrize("test_data", Neg.test_data)
@common.XfailIfNoCorstone300
def test_neg_u55_BI(test_data: input_t1):
    """Run ``Neg`` through ``EthosU55PipelineBI`` with ``run_on_fvp`` enabled."""
    model = Neg()
    u55_pipeline = EthosU55PipelineBI[input_t1](
        model,
        test_data,
        Neg.aten_op,
        Neg.exir_op,
        run_on_fvp=True,
    )
    u55_pipeline.run()
58+
59+
60+
@common.parametrize("test_data", Neg.test_data)
@common.XfailIfNoCorstone320
def test_neg_u85_BI(test_data: input_t1):
    """Run ``Neg`` through ``EthosU85PipelineBI`` with ``run_on_fvp`` enabled."""
    model = Neg()
    u85_pipeline = EthosU85PipelineBI[input_t1](
        model,
        test_data,
        Neg.aten_op,
        Neg.exir_op,
        run_on_fvp=True,
    )
    u85_pipeline.run()

backends/arm/test/ops/test_sdpa.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
7+
from typing import Tuple
8+
9+
import torch
10+
11+
from executorch.backends.arm.test.tester.test_pipeline import (
12+
TosaPipelineBI,
13+
TosaPipelineMI,
14+
)
15+
16+
17+
class SDPA(torch.nn.Module):
    """Thin wrapper around ``torch.nn.functional.scaled_dot_product_attention``."""

    # No extra state is needed, so the inherited constructor is used as-is.

    def forward(self, query, key, value):
        """Return attention over ``query``/``key``/``value`` with no mask,
        no dropout and no causal masking."""
        attention = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
        )
        return attention
25+
26+
27+
input_t = Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
28+
29+
30+
def test_sdpa_MI():
    """Run ``SDPA`` through ``TosaPipelineMI`` without the exir op-count check."""
    # Three independent random tensors: query, key and value.
    qkv = tuple(torch.randn(1, 3, 197, 64) for _ in range(3))
    mi_pipeline = TosaPipelineMI[input_t](SDPA(), qkv, [], [])
    mi_pipeline.pop_stage("check_count.exir")
    mi_pipeline.run()
35+
36+
37+
def test_sdpa_BI():
    """Run ``SDPA`` through ``TosaPipelineBI`` with three stages removed."""
    # Three independent random tensors: query, key and value.
    qkv = tuple(torch.randn(1, 3, 197, 64) for _ in range(3))
    bi_pipeline = TosaPipelineBI[input_t](SDPA(), qkv, [], [])

    # Stages are dropped in the same order as before.
    bi_pipeline.pop_stage("check.quant_nodes")
    bi_pipeline.pop_stage("check_count.exir")
    # TODO: reference is not quantized
    bi_pipeline.pop_stage("run_method_and_compare_outputs")

    bi_pipeline.run()

backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ python_unittest(
347347
":compiler",
348348
"//caffe2:torch",
349349
"//executorch/backends/cadence/aot:compiler",
350+
"//executorch/backends/cadence/aot:graph_builder",
350351
"//executorch/backends/cadence/aot:ops_registrations",
351352
"//executorch/backends/cadence/aot:pass_utils",
352353
"//executorch/backends/cadence/aot:remove_ops",

0 commit comments

Comments
 (0)