Skip to content

Commit 384ba50

Browse files
committed
[Backend Tester] Add permute, transpose, and masked_fill tests
ghstack-source-id: d89d4c3 ghstack-comment-id: 3116316539 Pull-Request: #12850
1 parent f3a2081 commit 384ba50

File tree

3 files changed

+417
-0
lines changed

3 files changed

+417
-0
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import Union
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class MaskedFillModel(torch.nn.Module):
22+
def __init__(self, value: Union[float, int]):
23+
super().__init__()
24+
self.value = value
25+
26+
def forward(self, x, mask):
27+
return x.masked_fill(mask, self.value)
28+
29+
30+
@operator_test
31+
class MaskedFill(OperatorTest):
32+
@dtype_test
33+
def test_masked_fill_dtype(self, flow: TestFlow, dtype) -> None:
34+
self._test_op(
35+
MaskedFillModel(value=0.0),
36+
(
37+
torch.rand(3, 4).to(dtype),
38+
torch.tensor(
39+
[
40+
[True, False, True, False],
41+
[False, True, False, True],
42+
[True, True, False, False],
43+
]
44+
),
45+
),
46+
flow,
47+
)
48+
49+
def test_masked_fill_basic(self, flow: TestFlow) -> None:
50+
self._test_op(
51+
MaskedFillModel(value=0.0),
52+
(
53+
torch.randn(3, 4),
54+
torch.tensor(
55+
[
56+
[True, False, True, False],
57+
[False, True, False, True],
58+
[True, True, False, False],
59+
]
60+
),
61+
),
62+
flow,
63+
)
64+
65+
def test_masked_fill_different_values(self, flow: TestFlow) -> None:
66+
self._test_op(
67+
MaskedFillModel(value=5.0),
68+
(
69+
torch.randn(3, 4),
70+
torch.tensor(
71+
[
72+
[True, False, True, False],
73+
[False, True, False, True],
74+
[True, True, False, False],
75+
]
76+
),
77+
),
78+
flow,
79+
)
80+
81+
self._test_op(
82+
MaskedFillModel(value=-5.0),
83+
(
84+
torch.randn(3, 4),
85+
torch.tensor(
86+
[
87+
[True, False, True, False],
88+
[False, True, False, True],
89+
[True, True, False, False],
90+
]
91+
),
92+
),
93+
flow,
94+
)
95+
96+
self._test_op(
97+
MaskedFillModel(value=1),
98+
(
99+
torch.randn(3, 4),
100+
torch.tensor(
101+
[
102+
[True, False, True, False],
103+
[False, True, False, True],
104+
[True, True, False, False],
105+
]
106+
),
107+
),
108+
flow,
109+
)
110+
111+
def test_masked_fill_different_shapes(self, flow: TestFlow) -> None:
112+
self._test_op(
113+
MaskedFillModel(value=0.0),
114+
(
115+
torch.randn(5),
116+
torch.tensor([True, False, True, False, True]),
117+
),
118+
flow,
119+
)
120+
121+
self._test_op(
122+
MaskedFillModel(value=0.0),
123+
(
124+
torch.randn(2, 3, 4),
125+
torch.tensor(
126+
[
127+
[
128+
[True, False, True, False],
129+
[False, True, False, True],
130+
[True, True, False, False],
131+
],
132+
[
133+
[False, False, True, True],
134+
[True, False, True, False],
135+
[False, True, False, True],
136+
],
137+
]
138+
),
139+
),
140+
flow,
141+
)
142+
143+
def test_masked_fill_all_true(self, flow: TestFlow) -> None:
144+
self._test_op(
145+
MaskedFillModel(value=0.0),
146+
(
147+
torch.randn(3, 4),
148+
torch.ones(3, 4, dtype=torch.bool),
149+
),
150+
flow,
151+
)
152+
153+
def test_masked_fill_all_false(self, flow: TestFlow) -> None:
154+
self._test_op(
155+
MaskedFillModel(value=0.0),
156+
(
157+
torch.randn(3, 4),
158+
torch.zeros(3, 4, dtype=torch.bool),
159+
),
160+
flow,
161+
)
162+
163+
def test_masked_fill_broadcast(self, flow: TestFlow) -> None:
164+
self._test_op(
165+
MaskedFillModel(value=0.0),
166+
(
167+
torch.randn(3, 4),
168+
torch.tensor([True, False, True, False]),
169+
),
170+
flow,
171+
)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import List
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class PermuteModel(torch.nn.Module):
22+
def __init__(self, dims: List[int]):
23+
super().__init__()
24+
self.dims = dims
25+
26+
def forward(self, x):
27+
return x.permute(self.dims)
28+
29+
30+
@operator_test
31+
class Permute(OperatorTest):
32+
@dtype_test
33+
def test_permute_dtype(self, flow: TestFlow, dtype) -> None:
34+
self._test_op(
35+
PermuteModel(dims=[1, 0]),
36+
(torch.rand(3, 4).to(dtype),),
37+
flow,
38+
)
39+
40+
def test_permute_basic(self, flow: TestFlow) -> None:
41+
self._test_op(
42+
PermuteModel(dims=[1, 0]),
43+
(torch.randn(3, 4),),
44+
flow,
45+
)
46+
47+
def test_permute_3d(self, flow: TestFlow) -> None:
48+
self._test_op(
49+
PermuteModel(dims=[2, 0, 1]),
50+
(torch.randn(2, 3, 4),),
51+
flow,
52+
)
53+
54+
self._test_op(
55+
PermuteModel(dims=[1, 2, 0]),
56+
(torch.randn(2, 3, 4),),
57+
flow,
58+
)
59+
60+
self._test_op(
61+
PermuteModel(dims=[0, 2, 1]),
62+
(torch.randn(2, 3, 4),),
63+
flow,
64+
)
65+
66+
def test_permute_4d(self, flow: TestFlow) -> None:
67+
self._test_op(
68+
PermuteModel(dims=[3, 2, 1, 0]),
69+
(torch.randn(2, 3, 4, 5),),
70+
flow,
71+
)
72+
73+
self._test_op(
74+
PermuteModel(dims=[0, 2, 1, 3]),
75+
(torch.randn(2, 3, 4, 5),),
76+
flow,
77+
)
78+
79+
def test_permute_identity(self, flow: TestFlow) -> None:
80+
self._test_op(
81+
PermuteModel(dims=[0, 1]),
82+
(torch.randn(3, 4),),
83+
flow,
84+
)
85+
86+
self._test_op(
87+
PermuteModel(dims=[0, 1, 2]),
88+
(torch.randn(2, 3, 4),),
89+
flow,
90+
)
91+
92+
def test_permute_different_shapes(self, flow: TestFlow) -> None:
93+
self._test_op(
94+
PermuteModel(dims=[0]),
95+
(torch.randn(5),),
96+
flow,
97+
)
98+
99+
self._test_op(
100+
PermuteModel(dims=[4, 3, 2, 1, 0]),
101+
(torch.randn(2, 3, 4, 5, 6),),
102+
flow,
103+
)

0 commit comments

Comments
 (0)