1 change: 1 addition & 0 deletions backends/arm/operator_support/__init__.py
@@ -10,6 +10,7 @@
pool_2d_support,
reduce_sum_support,
right_shift_support,
slice_copy_support,
to_copy_support,
tosa_supported_operators,
)
39 changes: 39 additions & 0 deletions backends/arm/operator_support/slice_copy_support.py
@@ -0,0 +1,39 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging

import torch.fx as fx
from executorch.backends.arm.operator_support.tosa_supported_operators import (
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import getNodeArgs
from executorch.exir.dialects._ops import ops as exir_ops

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


@register_tosa_support_check
class SliceCopySupported(SupportedTOSAOperatorCheck):
targets = [exir_ops.edge.aten.slice_copy.Tensor]

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: # type: ignore[override, misc]
if tosa_spec not in self.tosa_specs:
return False

inputs = getNodeArgs(node)
if len(inputs) == 5 and (step := inputs[4].number) != 1:
logger.warning(f"{node.target} with step size of {step} not supported.")
return False
return True
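
A minimal sketch of the two cases this check separates, using a toy module and shapes made up for illustration: slicing with the default step of 1 can be delegated, while any other step is rejected and the node is not delegated to the TOSA backend.

import torch

class SliceDefaultStep(torch.nn.Module):
    def forward(self, x):
        # Exports to aten.slice_copy.Tensor with step == 1 -> supported.
        return x[:, 1:4]

class SliceStep2(torch.nn.Module):
    def forward(self, x):
        # Exports to aten.slice_copy.Tensor with step == 2 -> rejected by SliceCopySupported.
        return x[:, 0:4:2]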
2 changes: 0 additions & 2 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -73,7 +73,6 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
def get_registered_tosa_support_checks(
tosa_spec: TosaSpecification,
) -> list[Type[SupportedTOSAOperatorCheck]]:

if tosa_spec not in _tosa_spec_support:
raise RuntimeError(
f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}"
@@ -155,7 +154,6 @@ def is_node_supported(
exir_ops.edge.aten._softmax.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten._log_softmax.default,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
19 changes: 14 additions & 5 deletions backends/arm/operators/op_add.py
@@ -45,6 +45,12 @@ def define_node(
# Handle int8 (quantized) and int32
assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]

dim_order = (
inputs[0].dim_order
if len(inputs[0].shape) > len(inputs[1].shape)
else inputs[1].dim_order
)

if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node
@@ -61,13 +67,14 @@
# output.dtype == ts.DType.INT32
add_output = output

input1, input2 = tutils.reshape_for_broadcast(
tosa_graph, rescaled_inputs, dim_order
)

# Do the INT32 Add
tosa_graph.addOperator(
TosaOp.Op().ADD,
[
rescaled_inputs[0].name,
rescaled_inputs[1].name,
],
[input1.name, input2.name],
[add_output.name],
None,
)
@@ -108,10 +115,12 @@ def define_node(
assert inputs[0].dtype == ts.DType.FP32
assert output.dtype == ts.DType.FP32

input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)

# MI lowering
tosa_graph.addOperator(
TosaOp.Op().ADD,
[inputs[0].name, inputs[1].name],
[input1.name, input2.name],
[output.name],
None,
)
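
The new reshape_for_broadcast calls are what let ADD accept operands of different rank. A hedged sketch of the shape arithmetic such a helper is expected to perform (the helper name comes from this diff; the body below is an assumption, not the actual tosa_utils implementation): the lower-rank operand is given leading singleton dimensions until both operands have equal rank, as TOSA elementwise operators require.

def align_ranks(shape_a: tuple, shape_b: tuple) -> tuple:
    # Prepend 1s to the lower-rank shape until both shapes share a rank.
    rank = max(len(shape_a), len(shape_b))
    a = (1,) * (rank - len(shape_a)) + tuple(shape_a)
    b = (1,) * (rank - len(shape_b)) + tuple(shape_b)
    return a, b

# Example: align_ranks((1, 4, 5), (4, 1)) -> ((1, 4, 5), (1, 4, 1)), matching the
# "3d_randn_diff_rank" case added to test_add.py further down in this diff.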
26 changes: 21 additions & 5 deletions backends/arm/operators/op_mul.py
@@ -24,6 +24,7 @@
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import reshape_for_broadcast
from serializer.tosa_serializer import TosaOp


@@ -43,6 +44,12 @@ def define_node(
output: TosaArg,
) -> None:
assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8

dim_order = (
inputs[0].dim_order
if len(inputs[0].shape) > len(inputs[1].shape)
else inputs[1].dim_order
)
input_A = inputs[0]
input_B = inputs[1]
input_qparams = get_input_qparams(node) # pyre-ignore[16]
@@ -68,15 +75,21 @@ def define_node(
output_shape = tutils.tosa_shape(output.shape, output.dim_order)
mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)

input1, input2 = tutils.reshape_for_broadcast(
tosa_graph,
[
input_A_rescaled,
input_B_rescaled,
],
dim_order,
)

# Do the INT32 Mul
attr = ts.TosaSerializerAttribute()
attr.MulAttribute(shift=0)
tosa_graph.addOperator(
TosaOp.Op().MUL,
[
input_A_rescaled.name,
input_B_rescaled.name,
],
[input1.name, input2.name],
[mul_output.name],
attr,
)
@@ -101,8 +114,11 @@ def define_node(
) -> None:
if inputs[0].dtype == ts.DType.INT8:
return super().define_node(node, tosa_graph, inputs, output)

input1, input2 = reshape_for_broadcast(tosa_graph, inputs)

attr = ts.TosaSerializerAttribute()
attr.MulAttribute(shift=0)
tosa_graph.addOperator(
TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr
TosaOp.Op().MUL, [input1.name, input2.name], [output.name], attr
)
6 changes: 4 additions & 2 deletions backends/arm/operators/op_slice.py
@@ -32,9 +32,11 @@ def define_node(
output: TosaArg,
) -> None:

# See slice_copy_support.py
assert len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)
Review comment (Contributor): Nit: add an error message
Reply (Collaborator, author): Done


# aten.slice_copy supports slicing one dimension at a time.
# The arguments are dimension of slicing, start index and end index.
assert len(inputs) == 4
# The arguments are the actual input, the dimension to slice, the start index, the end index, and an optional step or stride.
input_node, dim, start, end = inputs

# Translate and check parameters in PyTorch dim order.
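
For reference, a hedged example of how a Python slice maps onto the arguments unpacked above (module and shapes are made up for illustration):

import torch

class Slice(torch.nn.Module):
    def forward(self, x):
        # Exported to the edge dialect as aten.slice_copy.Tensor(x, dim=1, start=1, end=4)
        # with an implicit step of 1, which is the case this node visitor handles.
        return x[:, 1:4]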
7 changes: 6 additions & 1 deletion backends/arm/test/conftest.py
@@ -43,7 +43,7 @@ def pytest_configure(config):
)
# Only enable if we also have the TOSA reference model available.
pytest._test_options["corstone_fvp"] = True # type: ignore[attr-defined]

pytest._test_options["llama_inputs"] = config.option.llama_inputs # type: ignore[attr-defined]
pytest._test_options["fast_fvp"] = False # type: ignore[attr-defined]
if getattr(config.option, "fast_fvp", False):
pytest._test_options["fast_fvp"] = config.option.fast_fvp # type: ignore[attr-defined]
@@ -69,6 +69,11 @@ def try_addoption(*args, **kwargs):
try_addoption("--arm_quantize_io", action="store_true", help="Deprecated.")
try_addoption("--arm_run_corstoneFVP", action="store_true", help="Deprecated.")
try_addoption("--fast_fvp", action="store_true")
try_addoption(
"--llama_inputs",
nargs="+",
help="List of two files. Firstly .pt file. Secondly .json",
)


def pytest_sessionstart(session):
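
A hedged sketch of how a test consumes the new option through the conftest helpers used later in this diff (is_option_enabled and get_option), assuming the option was passed on the pytest command line:

from executorch.backends.arm.test import conftest

if conftest.is_option_enabled("llama_inputs"):
    checkpoint, params_file = conftest.get_option("llama_inputs")  # [<.pt path>, <.json path>]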
125 changes: 125 additions & 0 deletions backends/arm/test/models/test_llama_arm.py
@@ -0,0 +1,125 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Review comment (Contributor): I will take a look but just a nit
s/test_llama_arm.py/test_llama.py/
Reply (@zingo, Collaborator, Feb 26, 2025): Re "I will take a look ...": Merge time? :)
Reply (Collaborator, author): Renamed to test_llama.py

# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import os
import sys
import unittest

import torch

from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.examples.models.llama.export_llama_lib import (
build_args_parser,
get_llama_model,
)

from executorch.exir import EdgeCompileConfig

# Add the project dir to sys.path to work around the importlib.import_module() conditions in model_factory.py
this_files_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.abspath(os.path.join(this_files_dir, "../../../.."))
sys.path.append(project_dir)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class TestLlama(unittest.TestCase):
"""
Test class for Llama models. Which Llama model is tested depends on the command line parameters:
--llama_inputs <path to .pt file> <path to .json file>
Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
"""

_edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
_check_ir_validity=False,
_skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend.
)

def prepare_model(self):

checkpoint = None
params_file = None
if conftest.is_option_enabled("llama_inputs"):
param_list = conftest.get_option("llama_inputs")
assert (
isinstance(param_list, list) and len(param_list) == 2
), "invalid number of inputs for --llama_inputs"
checkpoint = param_list[0]
params_file = param_list[1]
assert isinstance(checkpoint, str) and isinstance(
params_file, str
), "invalid input for --llama_inputs"
else:
logger.warning(
"Skipping Llama test because no input was provided. To run, use --llama_inputs <.pt> <.json>"
)
return None, None, None

assert os.path.isfile(checkpoint) and os.path.isfile(
params_file
), "Invalid file paths"

# TODO: Enable key value cache
args = [
"--disable_dynamic_shape",
"-c",
checkpoint,
"-p",
params_file,
"--model",
"stories110m",
]
parser = build_args_parser()
args = parser.parse_args(args)

llama_model, llama_inputs, llama_meta = get_llama_model(args)

# TODO: Remove this workaround; the attention mask should not be persistent,
# and it only works if the input shape is always the same
freqs_c = "freqs_cos"
freqs_s = "freqs_sin"
for i in range(llama_model.n_layers):
val = llama_model.layers[i].attention.get_buffer("mask")
llama_model.layers[i].attention.register_buffer(
"mask", val, persistent=True
)
val = llama_model.layers[i].attention.rope.get_buffer(freqs_c)
llama_model.layers[i].attention.rope.register_buffer(
freqs_c, val, persistent=True
)
val = llama_model.layers[i].attention.rope.get_buffer(freqs_s)
llama_model.layers[i].attention.rope.register_buffer(
freqs_s, val, persistent=True
)

return llama_model, llama_inputs, llama_meta

def test_llama_tosa_MI(self):
llama_model, llama_inputs, llama_meta = self.prepare_model()
if llama_model is None and llama_inputs is None and llama_meta is None:
return  # Nothing to test; prepare_model() already logged why.

with torch.no_grad():
(
ArmTester(
llama_model,
example_inputs=llama_inputs,
compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
constant_methods=llama_meta,
)
.export()
.to_edge_transform_and_lower(
edge_compile_config=self._edge_compile_config
)
.check_count({"torch.ops.higher_order.executorch_call_delegate": 14})
.to_executorch()
.run_method_and_compare_outputs(
inputs=llama_inputs, atol=1.8, rtol=0.01
)
)
24 changes: 23 additions & 1 deletion backends/arm/test/ops/test_add.py
@@ -5,7 +5,6 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Tuple

import torch
@@ -61,6 +60,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
}


class Add3(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return x + y

test_data: dict[str, input_t2] = {
"3d_randn_diff_rank": (torch.randn(1, 4, 5), torch.randn(4, 1)),
"4d_randn_diff_rank": (torch.randn(1, 1, 4, 4), torch.randn(4, 1)),
"4d_randn_diff_rank_2": (torch.randn(4, 1), torch.randn(1, 1, 4, 5)),
}


@common.parametrize("test_data", Add.test_data)
def test_add_tosa_MI(test_data: input_t1):
pipeline = TosaPipelineMI[input_t1](Add(), test_data, aten_op, exir_op)
@@ -145,6 +155,18 @@ def test_add2_tosa_MI(test_data: input_t2):
pipeline.run()


@common.parametrize("test_data", Add3.test_data)
def test_add3_tosa_MI(test_data: input_t2):
pipeline = TosaPipelineMI[input_t2](Add3(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", Add3.test_data)
def test_add3_tosa_BI(test_data: input_t2):
pipeline = TosaPipelineBI[input_t2](Add3(), test_data, aten_op, exir_op)
pipeline.run()


@common.parametrize("test_data", Add2.test_data)
def test_add2_tosa_BI(test_data: input_t2):
pipeline = TosaPipelineBI[input_t2](Add2(), test_data, aten_op, exir_op)