Skip to content

Commit 0145604

Browse files
authored
Arm Backend: Add tests for stack.default (#14623)
Stack is not in the list of core ATen ops and is decomposed automatically when lowering the graph (https://docs.pytorch.org/docs/main/export.html#export-ir-decompositions), so only the tests need to be added. stack is in this decomp table: https://github.com/pytorch/pytorch/blob/5d749ceb92c2c28bcfbdf918b4ab99b1a91fcb50/torch/_decomp/__init__.py#L466 Signed-off-by: Agrima Khare <[email protected]>
1 parent edf6927 commit 0145604

File tree

1 file changed

+150
-0
lines changed

1 file changed

+150
-0
lines changed
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
from typing import Tuple
7+
8+
import torch
9+
import torch.nn as nn
10+
11+
from executorch.backends.arm.test import common
12+
from executorch.backends.arm.test.tester.test_pipeline import (
13+
EthosU55PipelineINT,
14+
EthosU85PipelineINT,
15+
TosaPipelineFP,
16+
TosaPipelineINT,
17+
VgfPipeline,
18+
)
19+
20+
test_data_suite = {
21+
# (test_name, test_data)
22+
"ones_two_tensors": lambda: ((torch.ones(1), torch.ones(1)), 0),
23+
"ones_and_rand_three_tensors": lambda: (
24+
(torch.ones(1, 2), torch.randn(1, 2), torch.randn(1, 2)),
25+
1,
26+
),
27+
"ones_and_rand_four_tensors": lambda: (
28+
(
29+
torch.ones(1, 2, 5),
30+
torch.randn(1, 2, 5),
31+
torch.randn(1, 2, 5),
32+
torch.randn(1, 2, 5),
33+
),
34+
-1,
35+
),
36+
"rand_two_tensors": lambda: (
37+
(torch.randn(2, 2, 4), torch.randn(2, 2, 4)),
38+
2,
39+
),
40+
"rand_two_tensors_dim_0": lambda: (
41+
(torch.randn(1, 2, 4, 4), torch.randn(1, 2, 4, 4)),
42+
),
43+
"rand_two_tensors_dim_2": lambda: (
44+
(torch.randn(2, 2, 3, 5), torch.randn(2, 2, 3, 5)),
45+
2,
46+
),
47+
"rand_large": lambda: (
48+
(
49+
10000 * torch.randn(2, 3, 1, 4),
50+
torch.randn(2, 3, 1, 4),
51+
torch.randn(2, 3, 1, 4),
52+
),
53+
-3,
54+
),
55+
}
56+
57+
58+
class Stack(nn.Module):
59+
aten_op = "torch.ops.aten.stack.default"
60+
exir_op = "executorch_exir_dialects_edge__ops_aten_cat_default"
61+
62+
def forward(self, n: tuple[torch.Tensor, ...], dim: int = 0):
63+
return torch.stack(n, dim)
64+
65+
66+
input_t1 = Tuple[torch.Tensor]
67+
68+
69+
@common.parametrize("test_module", test_data_suite)
70+
def test_stack_tosa_FP(test_module: input_t1):
71+
test_data = test_module()
72+
pipeline = TosaPipelineFP[input_t1](
73+
Stack(),
74+
test_data,
75+
aten_op=Stack.aten_op,
76+
exir_op=Stack.exir_op,
77+
use_to_edge_transform_and_lower=False,
78+
)
79+
pipeline.run()
80+
81+
82+
@common.parametrize("test_module", test_data_suite)
83+
def test_stack_tosa_INT(test_module: input_t1):
84+
test_data = test_module()
85+
pipeline = TosaPipelineINT[input_t1](
86+
Stack(),
87+
test_data,
88+
aten_op=Stack.aten_op,
89+
exir_op=Stack.exir_op,
90+
use_to_edge_transform_and_lower=False,
91+
)
92+
pipeline.run()
93+
94+
95+
@common.XfailIfNoCorstone300
96+
@common.parametrize("test_module", test_data_suite)
97+
def test_stack_u55_INT(test_module: input_t1):
98+
test_data = test_module()
99+
pipeline = EthosU55PipelineINT[input_t1](
100+
Stack(),
101+
test_data,
102+
aten_ops=Stack.aten_op,
103+
exir_ops=Stack.exir_op,
104+
use_to_edge_transform_and_lower=False,
105+
)
106+
pipeline.run()
107+
108+
109+
@common.XfailIfNoCorstone320
110+
@common.parametrize("test_module", test_data_suite)
111+
def test_stack_u85_INT(test_module: input_t1):
112+
test_data = test_module()
113+
pipeline = EthosU85PipelineINT[input_t1](
114+
Stack(),
115+
test_data,
116+
aten_ops=Stack.aten_op,
117+
exir_ops=Stack.exir_op,
118+
use_to_edge_transform_and_lower=False,
119+
)
120+
pipeline.run()
121+
122+
123+
@common.SkipIfNoModelConverter
124+
@common.parametrize("test_module", test_data_suite)
125+
def test_stack_vgf_FP(test_module: input_t1):
126+
test_data = test_module()
127+
pipeline = VgfPipeline[input_t1](
128+
Stack(),
129+
test_data,
130+
aten_op=Stack.aten_op,
131+
exir_op=Stack.exir_op,
132+
tosa_version="TOSA-1.0+FP",
133+
use_to_edge_transform_and_lower=False,
134+
)
135+
pipeline.run()
136+
137+
138+
@common.SkipIfNoModelConverter
139+
@common.parametrize("test_module", test_data_suite)
140+
def test_stack_vgf_INT(test_module: input_t1):
141+
test_data = test_module()
142+
pipeline = VgfPipeline[input_t1](
143+
Stack(),
144+
test_data,
145+
aten_op=Stack.aten_op,
146+
exir_op=Stack.exir_op,
147+
tosa_version="TOSA-1.0+INT",
148+
use_to_edge_transform_and_lower=False,
149+
)
150+
pipeline.run()

0 commit comments

Comments
 (0)