Commit 674b125

Arm: Add initial Llama model test case
Adds a Llama model test case for TOSA-0.80+MI. Handles Add and Mul where the inputs have different ranks. Adds a new unit test parameter, --llama_inputs; without it the test is skipped. Tested with the smaller stories models, see examples/models/llama/UTILS.md. Adds get_llama_model() to export_llama_lib, which the test case uses.

Change-Id: I003bbcee8f0cc35193d793a4af9b031453114e71
1 parent 08eb5be commit 674b125

13 files changed: +366 -23

backends/arm/operator_support/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
     pool_2d_support,
     reduce_sum_support,
     right_shift_support,
+    slice_copy_support,
     to_copy_support,
     tosa_supported_operators,
 )
backends/arm/operator_support/slice_copy_support.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_utils import getNodeArgs
+from executorch.exir.dialects._ops import ops as exir_ops
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+@register_tosa_support_check
+class SliceCopySupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten.slice_copy.Tensor]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:  # type: ignore[override, misc]
+        if tosa_spec not in self.tosa_specs:
+            return False
+
+        inputs = getNodeArgs(node)
+        if len(inputs) == 5 and (step := inputs[4].number) != 1:
+            logger.warning(f"{node.target} with step size of {step} not supported.")
+            return False
+        return True
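
A quick illustration of what this check admits: aten.slice_copy.Tensor carries an optional fifth step argument, and only the default step of 1 can be lowered to TOSA. A minimal sketch (shapes and values are illustrative, not from the commit):

import torch

x = torch.randn(4, 4)
# Step omitted, defaults to 1: eligible for delegation under this check.
torch.ops.aten.slice_copy.Tensor(x, 0, 0, 2)
# Explicit step of 2: is_node_tosa_supported() returns False and the node
# stays outside the TOSA partition.
torch.ops.aten.slice_copy.Tensor(x, 0, 0, 4, 2)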

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 0 additions & 2 deletions
@@ -62,7 +62,6 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
 def get_registered_tosa_support_checks(
     tosa_spec: TosaSpecification,
 ) -> list[Type[SupportedTOSAOperatorCheck]]:
-
     if tosa_spec not in _tosa_spec_support:
         raise RuntimeError

@@ -125,7 +124,6 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.select_copy.int,
             exir_ops.edge.aten._log_softmax.default,
-            exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.tanh.default,
             exir_ops.edge.aten.upsample_nearest2d.vec,

backends/arm/operators/op_add.py

Lines changed: 14 additions & 5 deletions
@@ -45,6 +45,12 @@ def define_node(
         # Handle int8 (quantized) and int32
         assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]

+        dim_order = (
+            inputs[0].dim_order
+            if len(inputs[0].shape) > len(inputs[1].shape)
+            else inputs[1].dim_order
+        )
+
         if inputs[0].dtype == ts.DType.INT8:
             rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                 tosa_graph, inputs, node
@@ -61,13 +67,14 @@ def define_node(
             # output.dtype == ts.DType.INT32
             add_output = output

+        input1, input2 = tutils.reshape_for_broadcast(
+            tosa_graph, rescaled_inputs, dim_order
+        )
+
         # Do the INT32 Add
         tosa_graph.addOperator(
             TosaOp.Op().ADD,
-            [
-                rescaled_inputs[0].name,
-                rescaled_inputs[1].name,
-            ],
+            [input1.name, input2.name],
             [add_output.name],
             None,
         )
@@ -108,10 +115,12 @@ def define_node(
         assert inputs[0].dtype == ts.DType.FP32
         assert output.dtype == ts.DType.FP32

+        input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
+
         # MI lowering
         tosa_graph.addOperator(
             TosaOp.Op().ADD,
-            [inputs[0].name, inputs[1].name],
+            [input1.name, input2.name],
             [output.name],
             None,
         )
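
Both call sites rely on tutils.reshape_for_broadcast from tosa_utils, whose definition is not part of this diff. Assuming it follows standard broadcasting semantics, the effect is to pad the lower-rank operand with leading singleton dimensions so both operands reach equal rank before the TOSA ADD is emitted; a rough torch-level sketch of that rule:

import torch

x = torch.randn(1, 4, 5)  # rank 3
y = torch.randn(4, 1)     # rank 2
# Pad the lower-rank operand with leading 1s: (4, 1) -> (1, 4, 1).
y_padded = y.reshape((1,) * (x.dim() - y.dim()) + tuple(y.shape))
# Both forms broadcast to the same output shape.
assert (x + y).shape == (x + y_padded).shape == (1, 4, 5)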

backends/arm/operators/op_mul.py

Lines changed: 21 additions & 5 deletions
@@ -24,6 +24,7 @@
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_utils import reshape_for_broadcast
 from serializer.tosa_serializer import TosaOp


@@ -43,6 +44,12 @@ def define_node(
         output: TosaArg,
     ) -> None:
         assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8
+
+        dim_order = (
+            inputs[0].dim_order
+            if len(inputs[0].shape) > len(inputs[1].shape)
+            else inputs[1].dim_order
+        )
         input_A = inputs[0]
         input_B = inputs[1]
         input_qparams = get_input_qparams(node)  # pyre-ignore[16]
@@ -68,15 +75,21 @@ def define_node(
         output_shape = tutils.tosa_shape(output.shape, output.dim_order)
         mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)

+        input1, input2 = tutils.reshape_for_broadcast(
+            tosa_graph,
+            [
+                input_A_rescaled,
+                input_B_rescaled,
+            ],
+            dim_order,
+        )
+
         # Do the INT32 Mul
         attr = ts.TosaSerializerAttribute()
         attr.MulAttribute(shift=0)
         tosa_graph.addOperator(
             TosaOp.Op().MUL,
-            [
-                input_A_rescaled.name,
-                input_B_rescaled.name,
-            ],
+            [input1.name, input2.name],
             [mul_output.name],
             attr,
         )
@@ -101,8 +114,11 @@ def define_node(
     ) -> None:
         if inputs[0].dtype == ts.DType.INT8:
             return super().define_node(node, tosa_graph, inputs, output)
+
+        input1, input2 = reshape_for_broadcast(tosa_graph, inputs)
+
         attr = ts.TosaSerializerAttribute()
         attr.MulAttribute(shift=0)
         tosa_graph.addOperator(
-            TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr
+            TosaOp.Op().MUL, [input1.name, input2.name], [output.name], attr
         )
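
Note that op_add.py and op_mul.py both select the dim_order of the higher-rank operand, i.e. the memory layout that the reshaped lower-rank operand must be expressed in. A small sketch of the selection rule (shapes and dim orders are illustrative):

# The operand with more dimensions dictates the dim order used for broadcasting.
a_shape, a_dim_order = (1, 1, 4, 4), (0, 1, 2, 3)
b_shape, b_dim_order = (4, 1), (0, 1)
dim_order = a_dim_order if len(a_shape) > len(b_shape) else b_dim_order
assert dim_order == (0, 1, 2, 3)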

backends/arm/operators/op_slice.py

Lines changed: 4 additions & 2 deletions
@@ -32,9 +32,11 @@ def define_node(
         output: TosaArg,
     ) -> None:

+        # See slice_copy_support.py
+        assert len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)
+
         # aten.slice_copy supports slicing one dimension at a time.
-        # The arguments are dimension of slicing, start index and end index.
-        assert len(inputs) == 4
+        # The arguments are the input tensor, the dimension to slice, the start
+        # index, the end index, and an optional step (stride).
         input_node, dim, start, end = inputs

         # Translate and check parameters in PyTorch dim order.
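
A short illustration of the argument layout the updated comment describes, expressed with torch directly (values are illustrative):

import torch

x = torch.randn(2, 8)
# (input, dim, start, end, step): take columns 2..5 of every row.
y = torch.ops.aten.slice_copy.Tensor(x, 1, 2, 6, 1)
assert y.shape == (2, 4)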

backends/arm/test/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -36,6 +36,7 @@ def pytest_configure(config):
     )
     pytest._test_options["corstone_fvp"] = True  # type: ignore[attr-defined]
     pytest._test_options["fast_fvp"] = config.option.fast_fvp  # type: ignore[attr-defined]
+    pytest._test_options["llama_inputs"] = config.option.llama_inputs  # type: ignore[attr-defined]
     logging.basicConfig(level=logging.INFO, stream=sys.stdout)


@@ -47,6 +48,11 @@ def pytest_addoption(parser):
     parser.addoption("--arm_quantize_io", action="store_true", help="Deprecated.")
     parser.addoption("--arm_run_corstoneFVP", action="store_true")
     parser.addoption("--fast_fvp", action="store_true")
+    parser.addoption(
+        "--llama_inputs",
+        nargs="+",
+        help="Two files: a .pt checkpoint followed by a .json params file.",
+    )


 def pytest_sessionstart(session):
backends/arm/test/models/test_llama.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import os
+import sys
+import unittest
+
+import torch
+
+from executorch.backends.arm.test import common, conftest
+from executorch.backends.arm.test.tester.arm_tester import ArmTester
+from executorch.examples.models.llama.export_llama_lib import (
+    build_args_parser,
+    get_llama_model,
+)
+
+from executorch.exir import EdgeCompileConfig
+
+# Add project dir to sys.path to work around importlib.import_module() conditions in model_factory.py
+this_files_dir = os.path.dirname(os.path.abspath(__file__))
+project_dir = os.path.abspath(os.path.join(this_files_dir, "../../../.."))
+sys.path.append(project_dir)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+class TestLlama(unittest.TestCase):
+    """
+    Test class for Llama models. Which model is used depends on the command line parameter:
+    --llama_inputs <path to .pt file> <path to .json file>
+    Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
+    """
+
+    _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig(
+        _check_ir_validity=False,
+        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
+    )
+
+    def prepare_model(self):
+
+        checkpoint = None
+        params_file = None
+        if conftest.is_option_enabled("llama_inputs"):
+            param_list = conftest.get_option("llama_inputs")
+            assert (
+                isinstance(param_list, list) and len(param_list) == 2
+            ), "invalid number of inputs for --llama_inputs"
+            checkpoint = param_list[0]
+            params_file = param_list[1]
+            assert isinstance(checkpoint, str) and isinstance(
+                params_file, str
+            ), "invalid input for --llama_inputs"
+        else:
+            logger.warning(
+                "Skipping Llama test because of missing input. To run, use --llama_inputs <.pt> <.json>"
+            )
+            self.skipTest("Llama inputs not provided")
+
+        assert os.path.isfile(checkpoint) and os.path.isfile(
+            params_file
+        ), "Invalid file paths"
+
+        # TODO: Enable key value cache
+        args = [
+            "--disable_dynamic_shape",
+            "-c",
+            checkpoint,
+            "-p",
+            params_file,
+            "--model",
+            "stories110m",
+        ]
+        parser = build_args_parser()
+        args = parser.parse_args(args)
+
+        llama_model, llama_inputs, llama_meta = get_llama_model(args)
+
+        # TODO: Remove workaround; the attention mask should not be persistent,
+        # and this only works if the input shape is always the same.
+        freqs_c = "freqs_cos"
+        freqs_s = "freqs_sin"
+        for i in range(llama_model.n_layers):
+            val = llama_model.layers[i].attention.get_buffer("mask")
+            llama_model.layers[i].attention.register_buffer(
+                "mask", val, persistent=True
+            )
+            val = llama_model.layers[i].attention.rope.get_buffer(freqs_c)
+            llama_model.layers[i].attention.rope.register_buffer(
+                freqs_c, val, persistent=True
+            )
+            val = llama_model.layers[i].attention.rope.get_buffer(freqs_s)
+            llama_model.layers[i].attention.rope.register_buffer(
+                freqs_s, val, persistent=True
+            )
+
+        return llama_model, llama_inputs, llama_meta
+
+    def test_llama_tosa_MI(self):
+        llama_model, llama_inputs, llama_meta = self.prepare_model()
+
+        with torch.no_grad():
+            (
+                ArmTester(
+                    llama_model,
+                    example_inputs=llama_inputs,
+                    compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
+                    constant_methods=llama_meta,
+                )
+                .export()
+                .to_edge_transform_and_lower(
+                    edge_compile_config=self._edge_compile_config
+                )
+                .check_count({"torch.ops.higher_order.executorch_call_delegate": 25})
+                .to_executorch()
+                .run_method_and_compare_outputs(
+                    inputs=llama_inputs, atol=1.8, rtol=0.01
+                )
+            )
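
To exercise the new test, the --llama_inputs option defined in conftest.py must be passed to pytest. A hypothetical invocation from Python (the stories110M file paths, and the assumption that the test file lives at backends/arm/test/models/test_llama.py, are not part of the diff):

import pytest

# Equivalent to:
#   pytest backends/arm/test/models/test_llama.py \
#       --llama_inputs stories110M/stories110M.pt stories110M/params.json
pytest.main([
    "backends/arm/test/models/test_llama.py",
    "--llama_inputs",
    "stories110M/stories110M.pt",
    "stories110M/params.json",
])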

backends/arm/test/ops/test_add.py

Lines changed: 23 additions & 1 deletion
@@ -5,7 +5,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-
 from typing import Tuple

 import torch
@@ -55,6 +54,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
 }


+class Add3(torch.nn.Module):
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        return x + y
+
+    test_data: dict[str, input_t2] = {
+        "3d_randn_diff_rank": (torch.randn(1, 4, 5), torch.randn(4, 1)),
+        "4d_randn_diff_rank": (torch.randn(1, 1, 4, 4), torch.randn(4, 1)),
+        "4d_randn_diff_rank_2": (torch.randn(4, 1), torch.randn(1, 1, 4, 5)),
+    }
+
+
 @common.parametrize("test_data", Add.test_data)
 def test_add_tosa_MI(test_data: input_t1):
     pipeline = TosaPipelineMI[input_t1](Add(), test_data, aten_op, exir_op)
@@ -107,6 +117,18 @@ def test_add2_tosa_MI(test_data: input_t2):
     pipeline.run()


+@common.parametrize("test_data", Add3.test_data)
+def test_add3_tosa_MI(test_data: input_t2):
+    pipeline = TosaPipelineMI[input_t2](Add3(), test_data, aten_op, exir_op)
+    pipeline.run()
+
+
+@common.parametrize("test_data", Add3.test_data)
+def test_add3_tosa_BI(test_data: input_t2):
+    pipeline = TosaPipelineBI[input_t2](Add3(), test_data, aten_op, exir_op)
+    pipeline.run()
+
+
 @common.parametrize("test_data", Add2.test_data)
 def test_add2_tosa_BI(test_data: input_t2):
     pipeline = TosaPipelineBI[input_t2](Add2(), test_data, aten_op, exir_op)
