Skip to content

Commit aacddf5

Browse files
committed
Update
[ghstack-poisoned]
1 parent ca1b887 commit aacddf5

File tree

10 files changed

+885
-0
lines changed

10 files changed

+885
-0
lines changed
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, List
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class CatModel(torch.nn.Module):
16+
def __init__(self, dim: int = 0):
17+
super().__init__()
18+
self.dim = dim
19+
20+
def forward(self, x1, x2, x3):
21+
return torch.cat([x1, x2, x3], dim=self.dim)
22+
23+
@operator_test
24+
class TestCat(OperatorTest):
25+
@dtype_test
26+
def test_cat_dtype(self, dtype, tester_factory: Callable) -> None:
27+
# Test with different dtypes
28+
model = CatModel()
29+
self._test_op(
30+
model,
31+
(
32+
torch.rand(2, 3).to(dtype),
33+
torch.rand(3, 3).to(dtype),
34+
torch.rand(4, 3).to(dtype),
35+
),
36+
tester_factory
37+
)
38+
39+
def test_cat_basic(self, tester_factory: Callable) -> None:
40+
# Basic test with default parameters
41+
# Concatenate 3 tensors along dimension 0
42+
# Tensors of shapes [2, 3], [3, 3], [4, 3] -> Result will be of shape [9, 3]
43+
self._test_op(
44+
CatModel(),
45+
(
46+
torch.randn(2, 3),
47+
torch.randn(3, 3),
48+
torch.randn(4, 3),
49+
),
50+
tester_factory
51+
)
52+
53+
def test_cat_dimensions(self, tester_factory: Callable) -> None:
54+
# Test concatenating along different dimensions
55+
56+
# Concatenate along dimension 0 (default)
57+
# Tensors of shapes [2, 3], [3, 3], [4, 3] -> Result will be of shape [9, 3]
58+
self._test_op(
59+
CatModel(dim=0),
60+
(
61+
torch.randn(2, 3),
62+
torch.randn(3, 3),
63+
torch.randn(4, 3),
64+
),
65+
tester_factory
66+
)
67+
68+
# Concatenate along dimension 1
69+
# Tensors of shapes [3, 2], [3, 3], [3, 4] -> Result will be of shape [3, 9]
70+
self._test_op(
71+
CatModel(dim=1),
72+
(
73+
torch.randn(3, 2),
74+
torch.randn(3, 3),
75+
torch.randn(3, 4),
76+
),
77+
tester_factory
78+
)
79+
80+
# Concatenate along dimension 2
81+
# Tensors of shapes [2, 3, 1], [2, 3, 2], [2, 3, 3] -> Result will be of shape [2, 3, 6]
82+
self._test_op(
83+
CatModel(dim=2),
84+
(
85+
torch.randn(2, 3, 1),
86+
torch.randn(2, 3, 2),
87+
torch.randn(2, 3, 3),
88+
),
89+
tester_factory
90+
)
91+
92+
def test_cat_negative_dim(self, tester_factory: Callable) -> None:
93+
# Test with negative dimensions (counting from the end)
94+
95+
# Concatenate along the last dimension (dim=-1)
96+
# For tensors of shape [3, 2], [3, 3], [3, 4], this is equivalent to dim=1
97+
# Result will be of shape [3, 9]
98+
self._test_op(
99+
CatModel(dim=-1),
100+
(
101+
torch.randn(3, 2),
102+
torch.randn(3, 3),
103+
torch.randn(3, 4),
104+
),
105+
tester_factory
106+
)
107+
108+
# Concatenate along the second-to-last dimension (dim=-2)
109+
# For tensors of shape [2, 3], [3, 3], [4, 3], this is equivalent to dim=0
110+
# Result will be of shape [9, 3]
111+
self._test_op(
112+
CatModel(dim=-2),
113+
(
114+
torch.randn(2, 3),
115+
torch.randn(3, 3),
116+
torch.randn(4, 3),
117+
),
118+
tester_factory
119+
)
120+
121+
def test_cat_different_shapes(self, tester_factory: Callable) -> None:
122+
# Test with tensors of different shapes
123+
124+
# Concatenate 1D tensors
125+
# Tensors of shapes [2], [3], [4] -> Result will be of shape [9]
126+
self._test_op(
127+
CatModel(),
128+
(
129+
torch.randn(2),
130+
torch.randn(3),
131+
torch.randn(4),
132+
),
133+
tester_factory
134+
)
135+
136+
# Concatenate 3D tensors along dimension 0
137+
# Tensors of shapes [1, 3, 4], [2, 3, 4], [3, 3, 4] -> Result will be of shape [6, 3, 4]
138+
self._test_op(
139+
CatModel(dim=0),
140+
(
141+
torch.randn(1, 3, 4),
142+
torch.randn(2, 3, 4),
143+
torch.randn(3, 3, 4),
144+
),
145+
tester_factory
146+
)
147+
148+
# Concatenate 3D tensors along dimension 1
149+
# Tensors of shapes [2, 1, 4], [2, 2, 4], [2, 3, 4] -> Result will be of shape [2, 6, 4]
150+
self._test_op(
151+
CatModel(dim=1),
152+
(
153+
torch.randn(2, 1, 4),
154+
torch.randn(2, 2, 4),
155+
torch.randn(2, 3, 4),
156+
),
157+
tester_factory
158+
)
159+
160+
# Concatenate 3D tensors along dimension 2
161+
# Tensors of shapes [2, 3, 1], [2, 3, 2], [2, 3, 3] -> Result will be of shape [2, 3, 6]
162+
self._test_op(
163+
CatModel(dim=2),
164+
(
165+
torch.randn(2, 3, 1),
166+
torch.randn(2, 3, 2),
167+
torch.randn(2, 3, 3),
168+
),
169+
tester_factory
170+
)
171+
172+
def test_cat_same_shapes(self, tester_factory: Callable) -> None:
173+
# Test with tensors of the same shape
174+
# Tensors of shapes [2, 3], [2, 3], [2, 3] -> Result will be of shape [6, 3]
175+
self._test_op(
176+
CatModel(),
177+
(
178+
torch.randn(2, 3),
179+
torch.randn(2, 3),
180+
torch.randn(2, 3),
181+
),
182+
tester_factory
183+
)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, List
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class ExpandModel(torch.nn.Module):
16+
def __init__(self, shape: List[int]):
17+
super().__init__()
18+
self.shape = shape
19+
20+
def forward(self, x):
21+
return x.expand(self.shape)
22+
23+
@operator_test
24+
class TestExpand(OperatorTest):
25+
@dtype_test
26+
def test_expand_dtype(self, dtype, tester_factory: Callable) -> None:
27+
# Test with different dtypes
28+
model = ExpandModel(shape=[3, 5])
29+
self._test_op(model, (torch.rand(1, 5).to(dtype),), tester_factory)
30+
31+
def test_expand_basic(self, tester_factory: Callable) -> None:
32+
# Basic test with default parameters
33+
# Expand from [1, 5] to [3, 5]
34+
self._test_op(ExpandModel(shape=[3, 5]), (torch.randn(1, 5),), tester_factory)
35+
36+
def test_expand_dimensions(self, tester_factory: Callable) -> None:
37+
# Test expanding different dimensions
38+
39+
# Expand first dimension
40+
self._test_op(ExpandModel(shape=[3, 5]), (torch.randn(1, 5),), tester_factory)
41+
42+
# Expand multiple dimensions
43+
self._test_op(ExpandModel(shape=[3, 4]), (torch.randn(1, 1),), tester_factory)
44+
45+
# Expand with adding a new dimension at the beginning
46+
self._test_op(ExpandModel(shape=[2, 1, 5]), (torch.randn(1, 5),), tester_factory)
47+
48+
# Expand with adding a new dimension in the middle
49+
self._test_op(ExpandModel(shape=[3, 2, 5]), (torch.randn(3, 1, 5),), tester_factory)
50+
51+
# Expand with adding a new dimension at the end
52+
self._test_op(ExpandModel(shape=[3, 5, 2]), (torch.randn(3, 5, 1),), tester_factory)
53+
54+
def test_expand_keep_original_size(self, tester_factory: Callable) -> None:
55+
# Test with -1 to keep the original size
56+
57+
# Keep the last dimension size
58+
self._test_op(ExpandModel(shape=[3, -1]), (torch.randn(1, 5),), tester_factory)
59+
60+
# Keep the first dimension size
61+
self._test_op(ExpandModel(shape=[-1, 5]), (torch.randn(2, 1),), tester_factory)
62+
63+
# Keep multiple dimension sizes
64+
self._test_op(ExpandModel(shape=[-1, 4, -1]), (torch.randn(2, 1, 3),), tester_factory)
65+
66+
def test_expand_singleton_dimensions(self, tester_factory: Callable) -> None:
67+
# Test expanding singleton dimensions
68+
69+
# Expand a scalar to a vector
70+
self._test_op(ExpandModel(shape=[5]), (torch.randn(1),), tester_factory)
71+
72+
# Expand a scalar to a matrix
73+
self._test_op(ExpandModel(shape=[3, 4]), (torch.randn(1, 1),), tester_factory)
74+
75+
# Expand a vector to a matrix by adding a dimension
76+
self._test_op(ExpandModel(shape=[3, 5]), (torch.randn(5),), tester_factory)
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, List
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class ReshapeModel(torch.nn.Module):
16+
def __init__(self, shape: List[int]):
17+
super().__init__()
18+
self.shape = shape
19+
20+
def forward(self, x):
21+
return torch.reshape(x, self.shape)
22+
23+
@operator_test
24+
class TestReshape(OperatorTest):
25+
@dtype_test
26+
def test_reshape_dtype(self, dtype, tester_factory: Callable) -> None:
27+
# Test with different dtypes
28+
model = ReshapeModel(shape=[3, 5])
29+
self._test_op(model, (torch.rand(15).to(dtype),), tester_factory)
30+
31+
def test_reshape_basic(self, tester_factory: Callable) -> None:
32+
# Basic test with default parameters
33+
# Reshape from [15] to [3, 5]
34+
self._test_op(ReshapeModel(shape=[3, 5]), (torch.randn(15),), tester_factory)
35+
36+
def test_reshape_dimensions(self, tester_factory: Callable) -> None:
37+
# Test reshaping to different dimensions
38+
39+
# Reshape from 1D to 2D
40+
self._test_op(ReshapeModel(shape=[3, 5]), (torch.randn(15),), tester_factory)
41+
42+
# Reshape from 2D to 1D
43+
self._test_op(ReshapeModel(shape=[20]), (torch.randn(4, 5),), tester_factory)
44+
45+
# Reshape from 2D to 3D
46+
self._test_op(ReshapeModel(shape=[2, 2, 5]), (torch.randn(4, 5),), tester_factory)
47+
48+
# Reshape from 3D to 2D
49+
self._test_op(ReshapeModel(shape=[6, 4]), (torch.randn(3, 2, 4),), tester_factory)
50+
51+
def test_reshape_inferred_dimension(self, tester_factory: Callable) -> None:
52+
# Test with inferred dimension (-1)
53+
54+
# Infer the last dimension
55+
self._test_op(ReshapeModel(shape=[3, -1]), (torch.randn(15),), tester_factory)
56+
57+
# Infer the first dimension
58+
self._test_op(ReshapeModel(shape=[-1, 5]), (torch.randn(15),), tester_factory)
59+
60+
# Infer the middle dimension
61+
self._test_op(ReshapeModel(shape=[2, -1, 3]), (torch.randn(24),), tester_factory)
62+

0 commit comments

Comments
 (0)