Skip to content

Commit de4b52e

Browse files
committed
Update
[ghstack-poisoned]
2 parents aacddf5 + 346ab2e commit de4b52e

24 files changed

+2034
-1238
lines changed

backends/test/suite/operators/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,9 @@ def _create_test_for_backend(
133133

134134

135135
class OperatorTest(unittest.TestCase):
136-
def _test_op(self, model, inputs, flow: TestFlow):
136+
def _test_op(
137+
self, model, inputs, flow: TestFlow, generate_random_test_inputs: bool = True
138+
):
137139
context = get_active_test_context()
138140

139141
# This should be set in the wrapped test. See _make_wrapped_test above.
@@ -145,6 +147,7 @@ def _test_op(self, model, inputs, flow: TestFlow):
145147
flow,
146148
context.test_name,
147149
context.params,
150+
generate_random_test_inputs=generate_random_test_inputs,
148151
)
149152

150153
log_test_summary(run_summary)
Lines changed: 65 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,183 +1,156 @@
1-
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
26

3-
# pyre-strict
7+
# pyre-unsafe
48

5-
from typing import Callable, List
69

710
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
812

9-
from executorch.backends.test.compliance_suite import (
13+
from executorch.backends.test.suite.operators import (
1014
dtype_test,
1115
operator_test,
1216
OperatorTest,
1317
)
1418

19+
1520
class CatModel(torch.nn.Module):
1621
def __init__(self, dim: int = 0):
1722
super().__init__()
1823
self.dim = dim
19-
24+
2025
def forward(self, x1, x2, x3):
2126
return torch.cat([x1, x2, x3], dim=self.dim)
2227

28+
2329
@operator_test
24-
class TestCat(OperatorTest):
30+
class Cat(OperatorTest):
2531
@dtype_test
26-
def test_cat_dtype(self, dtype, tester_factory: Callable) -> None:
27-
# Test with different dtypes
28-
model = CatModel()
32+
def test_cat_dtype(self, flow: TestFlow, dtype) -> None:
2933
self._test_op(
30-
model,
34+
CatModel(),
3135
(
3236
torch.rand(2, 3).to(dtype),
3337
torch.rand(3, 3).to(dtype),
3438
torch.rand(4, 3).to(dtype),
35-
),
36-
tester_factory
39+
),
40+
flow,
3741
)
38-
39-
def test_cat_basic(self, tester_factory: Callable) -> None:
40-
# Basic test with default parameters
41-
# Concatenate 3 tensors along dimension 0
42-
# Tensors of shapes [2, 3], [3, 3], [4, 3] -> Result will be of shape [9, 3]
42+
43+
def test_cat_basic(self, flow: TestFlow) -> None:
4344
self._test_op(
44-
CatModel(),
45+
CatModel(),
4546
(
4647
torch.randn(2, 3),
4748
torch.randn(3, 3),
4849
torch.randn(4, 3),
49-
),
50-
tester_factory
50+
),
51+
flow,
5152
)
52-
53-
def test_cat_dimensions(self, tester_factory: Callable) -> None:
54-
# Test concatenating along different dimensions
55-
56-
# Concatenate along dimension 0 (default)
57-
# Tensors of shapes [2, 3], [3, 3], [4, 3] -> Result will be of shape [9, 3]
53+
54+
def test_cat_dimensions(self, flow: TestFlow) -> None:
5855
self._test_op(
59-
CatModel(dim=0),
56+
CatModel(dim=0),
6057
(
6158
torch.randn(2, 3),
6259
torch.randn(3, 3),
6360
torch.randn(4, 3),
64-
),
65-
tester_factory
61+
),
62+
flow,
6663
)
67-
68-
# Concatenate along dimension 1
69-
# Tensors of shapes [3, 2], [3, 3], [3, 4] -> Result will be of shape [3, 9]
64+
7065
self._test_op(
71-
CatModel(dim=1),
66+
CatModel(dim=1),
7267
(
7368
torch.randn(3, 2),
7469
torch.randn(3, 3),
7570
torch.randn(3, 4),
76-
),
77-
tester_factory
71+
),
72+
flow,
7873
)
79-
80-
# Concatenate along dimension 2
81-
# Tensors of shapes [2, 3, 1], [2, 3, 2], [2, 3, 3] -> Result will be of shape [2, 3, 6]
74+
8275
self._test_op(
83-
CatModel(dim=2),
76+
CatModel(dim=2),
8477
(
8578
torch.randn(2, 3, 1),
8679
torch.randn(2, 3, 2),
8780
torch.randn(2, 3, 3),
88-
),
89-
tester_factory
81+
),
82+
flow,
9083
)
91-
92-
def test_cat_negative_dim(self, tester_factory: Callable) -> None:
93-
# Test with negative dimensions (counting from the end)
94-
95-
# Concatenate along the last dimension (dim=-1)
96-
# For tensors of shape [3, 2], [3, 3], [3, 4], this is equivalent to dim=1
97-
# Result will be of shape [3, 9]
84+
85+
def test_cat_negative_dim(self, flow: TestFlow) -> None:
9886
self._test_op(
99-
CatModel(dim=-1),
87+
CatModel(dim=-1),
10088
(
10189
torch.randn(3, 2),
10290
torch.randn(3, 3),
10391
torch.randn(3, 4),
104-
),
105-
tester_factory
92+
),
93+
flow,
10694
)
107-
108-
# Concatenate along the second-to-last dimension (dim=-2)
109-
# For tensors of shape [2, 3], [3, 3], [4, 3], this is equivalent to dim=0
110-
# Result will be of shape [9, 3]
95+
11196
self._test_op(
112-
CatModel(dim=-2),
97+
CatModel(dim=-2),
11398
(
11499
torch.randn(2, 3),
115100
torch.randn(3, 3),
116101
torch.randn(4, 3),
117-
),
118-
tester_factory
102+
),
103+
flow,
119104
)
120-
121-
def test_cat_different_shapes(self, tester_factory: Callable) -> None:
122-
# Test with tensors of different shapes
123-
124-
# Concatenate 1D tensors
125-
# Tensors of shapes [2], [3], [4] -> Result will be of shape [9]
105+
106+
def test_cat_different_shapes(self, flow: TestFlow) -> None:
126107
self._test_op(
127-
CatModel(),
108+
CatModel(),
128109
(
129110
torch.randn(2),
130111
torch.randn(3),
131112
torch.randn(4),
132-
),
133-
tester_factory
113+
),
114+
flow,
134115
)
135-
136-
# Concatenate 3D tensors along dimension 0
137-
# Tensors of shapes [1, 3, 4], [2, 3, 4], [3, 3, 4] -> Result will be of shape [6, 3, 4]
116+
138117
self._test_op(
139-
CatModel(dim=0),
118+
CatModel(dim=0),
140119
(
141120
torch.randn(1, 3, 4),
142121
torch.randn(2, 3, 4),
143122
torch.randn(3, 3, 4),
144-
),
145-
tester_factory
123+
),
124+
flow,
146125
)
147-
148-
# Concatenate 3D tensors along dimension 1
149-
# Tensors of shapes [2, 1, 4], [2, 2, 4], [2, 3, 4] -> Result will be of shape [2, 6, 4]
126+
150127
self._test_op(
151-
CatModel(dim=1),
128+
CatModel(dim=1),
152129
(
153130
torch.randn(2, 1, 4),
154131
torch.randn(2, 2, 4),
155132
torch.randn(2, 3, 4),
156-
),
157-
tester_factory
133+
),
134+
flow,
158135
)
159-
160-
# Concatenate 3D tensors along dimension 2
161-
# Tensors of shapes [2, 3, 1], [2, 3, 2], [2, 3, 3] -> Result will be of shape [2, 3, 6]
136+
162137
self._test_op(
163-
CatModel(dim=2),
138+
CatModel(dim=2),
164139
(
165140
torch.randn(2, 3, 1),
166141
torch.randn(2, 3, 2),
167142
torch.randn(2, 3, 3),
168-
),
169-
tester_factory
143+
),
144+
flow,
170145
)
171-
172-
def test_cat_same_shapes(self, tester_factory: Callable) -> None:
173-
# Test with tensors of the same shape
174-
# Tensors of shapes [2, 3], [2, 3], [2, 3] -> Result will be of shape [6, 3]
146+
147+
def test_cat_same_shapes(self, flow: TestFlow) -> None:
175148
self._test_op(
176-
CatModel(),
149+
CatModel(),
177150
(
178151
torch.randn(2, 3),
179152
torch.randn(2, 3),
180153
torch.randn(2, 3),
181-
),
182-
tester_factory
154+
),
155+
flow,
183156
)

0 commit comments

Comments
 (0)