Skip to content
Closed
10 changes: 10 additions & 0 deletions backends/arm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .arm_backend import ArmCompileSpecBuilder # noqa # usort: skip
from .tosa_backend import TOSABackend # noqa # usort: skip
from .tosa_partitioner import TOSAPartitioner # noqa # usort: skip
from .ethosu_backend import EthosUBackend # noqa # usort: skip
from .ethosu_partitioner import EthosUPartitioner # noqa # usort: skip
14 changes: 12 additions & 2 deletions backends/arm/_passes/convert_expand_copy_to_repeat.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
from typing import cast

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

logger = logging.getLogger(__name__)


class ConvertExpandCopyToRepeatPass(ExportPass):
"""
Expand Down Expand Up @@ -41,6 +43,14 @@ def call_operator(self, op, args, kwargs, meta):
multiples[i] if multiples[i] != -1 and extended_shape[i] == 1 else 1
for i in range(expanded_rank)
]

if all((x == 1 for x in multiples)):
# All dimensions/repetitions occur only once. Remove node
# altogether since it's in practice just a copy.
logger.warning("Found redundant expand node (no-op). Removing it.")

return args[0]

return super().call_operator(
op=self.repeat, args=(args[0], multiples), kwargs=kwargs, meta=meta
)
4 changes: 2 additions & 2 deletions backends/arm/ethosu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
import logging
from typing import final, List

from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm import TOSABackend

from executorch.backends.arm.tosa_backend import TOSABackend
from executorch.backends.arm.arm_vela import vela_compile
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from torch.export.exported_program import ExportedProgram
Expand Down
3 changes: 1 addition & 2 deletions backends/arm/ethosu_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from executorch.backends.arm.arm_backend import (
is_ethosu,
) # usort: skip
from executorch.backends.arm.ethosu_backend import EthosUBackend
from executorch.backends.arm.tosa_partitioner import TOSAPartitioner
from executorch.backends.arm import EthosUBackend, TOSAPartitioner
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.partitioner import DelegationSpec
from torch.fx.passes.operator_support import OperatorSupportBase
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ def is_node_supported(
exir_ops.edge.aten._log_softmax.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_bilinear2d.vec,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.var.correction,
exir_ops.edge.aten.var.dim,
Expand Down Expand Up @@ -365,6 +366,7 @@ def is_node_supported(
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_bilinear2d.vec,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.gelu.default,
):
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
op_to_copy,
op_to_dim_order_copy,
op_transpose,
op_upsample_bilinear2d,
op_upsample_nearest2d,
op_view,
op_where,
Expand Down
10 changes: 9 additions & 1 deletion backends/arm/operators/op_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,15 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:
assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8
if (
inputs[0].dtype != ts.DType.INT8
or inputs[1].dtype != ts.DType.INT8
or output.dtype != ts.DType.INT8
):
raise ValueError(
f"Inputs and output for {self.target} need to be INT8, got "
f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}"
)

dim_order = (
inputs[0].dim_order
Expand Down
100 changes: 100 additions & 0 deletions backends/arm/operators/op_upsample_bilinear2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import build_rescale
from executorch.backends.arm.tosa_utils import get_resize_parameters, tosa_shape
from tosa_tools.v0_80.tosa.ResizeMode import ResizeMode # type: ignore


@register_node_visitor
class UpsampleBilinear2dVisitor_0_80(NodeVisitor):
target = "aten.upsample_bilinear2d.vec"

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
assert (
inputs[0].shape is not None and output.shape is not None
), "Only static shapes are supported"

input_dtype = inputs[0].dtype

# tosa_shape output is NHWC, take HW
input_size_yx = torch.tensor(
tosa_shape(inputs[0].shape, inputs[0].dim_order)[1:3]
)
# Ignore scale and size parameters, directly use the output size as
# we only support static shapes currently
output_size_yx = torch.tensor(tosa_shape(output.shape, output.dim_order)[1:3])

scale_n_yx, scale_d_yx, offset_yx, border_yx = get_resize_parameters(
input_size_yx, output_size_yx, ResizeMode.NEAREST, align_corners=True
)

def in_int16_range(x):
return torch.all(x >= -(2**15)) and torch.all(x <= 2**15 - 1)

assert in_int16_range(scale_n_yx)
assert in_int16_range(scale_d_yx)
assert in_int16_range(border_yx)

attr = ts.TosaSerializerAttribute()
attr.ResizeAttribute(
scale=[scale_n_yx[0], scale_d_yx[0], scale_n_yx[1], scale_d_yx[1]],
offset=offset_yx.tolist(),
border=border_yx.tolist(),
mode=ResizeMode.BILINEAR,
)

if input_dtype == output.dtype == ts.DType.FP32:
tosa_graph.addOperator(
ts.TosaOp.Op().RESIZE, [inputs[0].name], [output.name], attr
)
return
elif input_dtype == output.dtype == ts.DType.INT8:
intermediate = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
)

tosa_graph.addOperator(
ts.TosaOp.Op().RESIZE, [inputs[0].name], [intermediate.name], attr
)

final_output_scale = float(1 / (scale_n_yx[0] * scale_n_yx[1]))

build_rescale(
tosa_fb=tosa_graph,
scale=[final_output_scale],
input_node=intermediate,
output_name=output.name,
output_type=ts.DType.INT8,
output_shape=output.shape,
input_zp=0,
output_zp=0,
is_double_round=False,
)
else:
raise ValueError(
"Input/output dtype not in {float32, int8}: {input_dtype=} {output.dtype=}"
)
10 changes: 9 additions & 1 deletion backends/arm/quantizer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from .quantization_config import QuantizationConfig # noqa # usort: skip
from .arm_quantizer import ( # noqa
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
)
3 changes: 1 addition & 2 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
import torch
from executorch.backends.arm._passes import ArmPassManager

from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer import arm_quantizer_utils, QuantizationConfig
from executorch.backends.arm.quantizer.arm_quantizer_utils import ( # type: ignore[attr-defined]
mark_node_as_annotated,
)
from executorch.backends.arm.quantizer.quantization_annotator import ( # type: ignore[import-not-found]
annotate_graph,
)

from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.arm_backend import (
get_tosa_spec,
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

import torch
import torch.fx
from executorch.backends.arm.quantizer import arm_quantizer_utils
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.quantizer import arm_quantizer_utils, QuantizationConfig
from executorch.backends.arm.tosa_utils import get_node_debug_info
from torch.ao.quantization.quantizer import QuantizationSpecBase, SharedQuantizationSpec
from torch.ao.quantization.quantizer.utils import (
Expand Down Expand Up @@ -215,6 +214,7 @@ def _match_pattern(
torch.ops.aten.flip.default,
torch.ops.aten.chunk.default,
torch.ops.aten.contiguous.default,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.upsample_nearest2d.vec,
torch.ops.aten.pad.default,
torch.ops.aten.amax.default,
Expand Down
35 changes: 26 additions & 9 deletions backends/arm/test/ops/test_expand.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -16,7 +15,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand All @@ -37,14 +36,14 @@ class Expand(torch.nn.Module):
# (input tensor, multiples)
test_parameters = [
(torch.rand(1), (2,)),
(torch.randn(1, 4), (1, -1)),
(torch.randn(1), (2, 2, 4)),
(torch.randn(1, 1, 1, 5), (1, 4, -1, -1)),
(torch.randn(1, 1, 192), (1, -1, -1)),
(torch.randn(1, 1), (1, 2, 2, 4)),
(torch.randn(1, 1), (2, 2, 2, 4)),
(torch.randn(10, 1, 1, 97), (-1, 4, -1, -1)),
(torch.rand(1, 1, 2, 2), (4, 3, -1, 2)),
(torch.randn(1, 4), (1, -1)),
(torch.randn(1, 1, 192), (1, -1, -1)),
]

def forward(self, x: torch.Tensor, m: Sequence):
Expand Down Expand Up @@ -117,34 +116,52 @@ def test_expand_tosa_MI(self, test_input, multiples):
def test_expand_tosa_BI(self, test_input, multiples):
self._test_expand_tosa_BI_pipeline(self.Expand(), (test_input, multiples))

@parameterized.expand(Expand.test_parameters[:-3])
@parameterized.expand(Expand.test_parameters[:-5])
@pytest.mark.corstone_fvp
def test_expand_u55_BI(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
)

# MLETORCH-629: Expand does not work on FVP with batch>1
@parameterized.expand(Expand.test_parameters[-3:])
@parameterized.expand(Expand.test_parameters[-5:-2])
@pytest.mark.corstone_fvp
@conftest.expectedFailureOnFVP
def test_expand_u55_BI_xfails_on_fvp(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
)

@parameterized.expand(Expand.test_parameters[-2:])
@pytest.mark.xfail(
reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
)
def test_expand_u55_BI_xfails(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u55_compile_spec(), self.Expand(), (test_input, multiples)
)

@parameterized.expand(Expand.test_parameters[:-3])
@parameterized.expand(Expand.test_parameters[:-5])
@pytest.mark.corstone_fvp
def test_expand_u85_BI(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
)

# MLETORCH-629: Expand does not work on FVP with batch>1
@parameterized.expand(Expand.test_parameters[-3:])
@parameterized.expand(Expand.test_parameters[-5:-2])
@pytest.mark.corstone_fvp
@conftest.expectedFailureOnFVP
def test_expand_u85_BI_xfails(self, test_input, multiples):
def test_expand_u85_BI_xfails_on_fvp(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
)

@parameterized.expand(Expand.test_parameters[-2:])
@pytest.mark.xfail(
reason="MLETORCH-716: Node will be optimized away and Vela can't handle empty graphs"
)
def test_expand_u85_xfails(self, test_input, multiples):
self._test_expand_ethosu_BI_pipeline(
common.get_u85_compile_spec(), self.Expand(), (test_input, multiples)
)
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_hardtanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_max_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_relu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Tuple

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/test/ops/test_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torch

from executorch.backends.arm.quantizer.arm_quantizer import (
from executorch.backends.arm.quantizer import (
EthosUQuantizer,
get_symmetric_quantization_config,
TOSAQuantizer,
Expand Down
Loading
Loading