Skip to content

Commit 04ee2df

Browse files
committed
[Backend Tester] Add slice and reshape tests
ghstack-source-id: 0553f1e ghstack-comment-id: 3116316656 Pull-Request: #12851
1 parent 384ba50 commit 04ee2df

File tree

10 files changed

+1072
-0
lines changed

10 files changed

+1072
-0
lines changed
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class CatModel(torch.nn.Module):
21+
def __init__(self, dim: int = 0):
22+
super().__init__()
23+
self.dim = dim
24+
25+
def forward(self, x1, x2, x3):
26+
return torch.cat([x1, x2, x3], dim=self.dim)
27+
28+
29+
@operator_test
30+
class Cat(OperatorTest):
31+
@dtype_test
32+
def test_cat_dtype(self, flow: TestFlow, dtype) -> None:
33+
self._test_op(
34+
CatModel(),
35+
(
36+
torch.rand(2, 3).to(dtype),
37+
torch.rand(3, 3).to(dtype),
38+
torch.rand(4, 3).to(dtype),
39+
),
40+
flow,
41+
)
42+
43+
def test_cat_basic(self, flow: TestFlow) -> None:
44+
self._test_op(
45+
CatModel(),
46+
(
47+
torch.randn(2, 3),
48+
torch.randn(3, 3),
49+
torch.randn(4, 3),
50+
),
51+
flow,
52+
)
53+
54+
def test_cat_dimensions(self, flow: TestFlow) -> None:
55+
self._test_op(
56+
CatModel(dim=0),
57+
(
58+
torch.randn(2, 3),
59+
torch.randn(3, 3),
60+
torch.randn(4, 3),
61+
),
62+
flow,
63+
)
64+
65+
self._test_op(
66+
CatModel(dim=1),
67+
(
68+
torch.randn(3, 2),
69+
torch.randn(3, 3),
70+
torch.randn(3, 4),
71+
),
72+
flow,
73+
)
74+
75+
self._test_op(
76+
CatModel(dim=2),
77+
(
78+
torch.randn(2, 3, 1),
79+
torch.randn(2, 3, 2),
80+
torch.randn(2, 3, 3),
81+
),
82+
flow,
83+
)
84+
85+
def test_cat_negative_dim(self, flow: TestFlow) -> None:
86+
self._test_op(
87+
CatModel(dim=-1),
88+
(
89+
torch.randn(3, 2),
90+
torch.randn(3, 3),
91+
torch.randn(3, 4),
92+
),
93+
flow,
94+
)
95+
96+
self._test_op(
97+
CatModel(dim=-2),
98+
(
99+
torch.randn(2, 3),
100+
torch.randn(3, 3),
101+
torch.randn(4, 3),
102+
),
103+
flow,
104+
)
105+
106+
def test_cat_different_shapes(self, flow: TestFlow) -> None:
107+
self._test_op(
108+
CatModel(),
109+
(
110+
torch.randn(2),
111+
torch.randn(3),
112+
torch.randn(4),
113+
),
114+
flow,
115+
)
116+
117+
self._test_op(
118+
CatModel(dim=0),
119+
(
120+
torch.randn(1, 3, 4),
121+
torch.randn(2, 3, 4),
122+
torch.randn(3, 3, 4),
123+
),
124+
flow,
125+
)
126+
127+
self._test_op(
128+
CatModel(dim=1),
129+
(
130+
torch.randn(2, 1, 4),
131+
torch.randn(2, 2, 4),
132+
torch.randn(2, 3, 4),
133+
),
134+
flow,
135+
)
136+
137+
self._test_op(
138+
CatModel(dim=2),
139+
(
140+
torch.randn(2, 3, 1),
141+
torch.randn(2, 3, 2),
142+
torch.randn(2, 3, 3),
143+
),
144+
flow,
145+
)
146+
147+
def test_cat_same_shapes(self, flow: TestFlow) -> None:
148+
self._test_op(
149+
CatModel(),
150+
(
151+
torch.randn(2, 3),
152+
torch.randn(2, 3),
153+
torch.randn(2, 3),
154+
),
155+
flow,
156+
)
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import List
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class ExpandModel(torch.nn.Module):
22+
def __init__(self, shape: List[int]):
23+
super().__init__()
24+
self.shape = shape
25+
26+
def forward(self, x):
27+
return x.expand(self.shape)
28+
29+
30+
@operator_test
31+
class Expand(OperatorTest):
32+
@dtype_test
33+
def test_expand_dtype(self, flow: TestFlow, dtype) -> None:
34+
self._test_op(
35+
ExpandModel(shape=[3, 5]),
36+
(torch.rand(1, 5).to(dtype),),
37+
flow,
38+
)
39+
40+
def test_expand_basic(self, flow: TestFlow) -> None:
41+
self._test_op(
42+
ExpandModel(shape=[3, 5]),
43+
(torch.randn(1, 5),),
44+
flow,
45+
)
46+
47+
def test_expand_dimensions(self, flow: TestFlow) -> None:
48+
self._test_op(
49+
ExpandModel(shape=[3, 5]),
50+
(torch.randn(1, 5),),
51+
flow,
52+
)
53+
54+
self._test_op(
55+
ExpandModel(shape=[3, 4]),
56+
(torch.randn(1, 1),),
57+
flow,
58+
)
59+
60+
self._test_op(
61+
ExpandModel(shape=[2, 1, 5]),
62+
(torch.randn(1, 5),),
63+
flow,
64+
)
65+
66+
self._test_op(
67+
ExpandModel(shape=[3, 2, 5]),
68+
(torch.randn(3, 1, 5),),
69+
flow,
70+
)
71+
72+
self._test_op(
73+
ExpandModel(shape=[3, 5, 2]),
74+
(torch.randn(3, 5, 1),),
75+
flow,
76+
)
77+
78+
def test_expand_keep_original_size(self, flow: TestFlow) -> None:
79+
self._test_op(
80+
ExpandModel(shape=[3, -1]),
81+
(torch.randn(1, 5),),
82+
flow,
83+
)
84+
85+
self._test_op(
86+
ExpandModel(shape=[-1, 5]),
87+
(torch.randn(2, 1),),
88+
flow,
89+
)
90+
91+
self._test_op(
92+
ExpandModel(shape=[-1, 4, -1]),
93+
(torch.randn(2, 1, 3),),
94+
flow,
95+
)
96+
97+
def test_expand_singleton_dimensions(self, flow: TestFlow) -> None:
98+
self._test_op(
99+
ExpandModel(shape=[5]),
100+
(torch.randn(1),),
101+
flow,
102+
)
103+
104+
self._test_op(
105+
ExpandModel(shape=[3, 4]),
106+
(torch.randn(1, 1),),
107+
flow,
108+
)
109+
110+
self._test_op(
111+
ExpandModel(shape=[3, 5]),
112+
(torch.randn(5),),
113+
flow,
114+
)
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import List
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class ReshapeModel(torch.nn.Module):
22+
def __init__(self, shape: List[int]):
23+
super().__init__()
24+
self.shape = shape
25+
26+
def forward(self, x):
27+
return torch.reshape(x, self.shape)
28+
29+
30+
@operator_test
31+
class Reshape(OperatorTest):
32+
@dtype_test
33+
def test_reshape_dtype(self, flow: TestFlow, dtype) -> None:
34+
self._test_op(
35+
ReshapeModel(shape=[3, 5]),
36+
(torch.rand(15).to(dtype),),
37+
flow,
38+
)
39+
40+
def test_reshape_basic(self, flow: TestFlow) -> None:
41+
self._test_op(
42+
ReshapeModel(shape=[3, 5]),
43+
(torch.randn(15),),
44+
flow,
45+
)
46+
47+
def test_reshape_dimensions(self, flow: TestFlow) -> None:
48+
self._test_op(
49+
ReshapeModel(shape=[3, 5]),
50+
(torch.randn(15),),
51+
flow,
52+
)
53+
54+
self._test_op(
55+
ReshapeModel(shape=[20]),
56+
(torch.randn(4, 5),),
57+
flow,
58+
)
59+
60+
self._test_op(
61+
ReshapeModel(shape=[2, 2, 5]),
62+
(torch.randn(4, 5),),
63+
flow,
64+
)
65+
66+
self._test_op(
67+
ReshapeModel(shape=[6, 4]),
68+
(torch.randn(3, 2, 4),),
69+
flow,
70+
)
71+
72+
def test_reshape_inferred_dimension(self, flow: TestFlow) -> None:
73+
self._test_op(
74+
ReshapeModel(shape=[3, -1]),
75+
(torch.randn(15),),
76+
flow,
77+
)
78+
79+
self._test_op(
80+
ReshapeModel(shape=[-1, 5]),
81+
(torch.randn(15),),
82+
flow,
83+
)
84+
85+
self._test_op(
86+
ReshapeModel(shape=[2, -1, 3]),
87+
(torch.randn(24),),
88+
flow,
89+
)

0 commit comments

Comments
 (0)