Skip to content

Commit 64d8223

Browse files
authored
[Backend Tester] Add permute, transpose, and masked_fill tests (#12850)
Add tests for permute, transpose, and masked_fill.
1 parent 154a259 commit 64d8223

File tree

3 files changed

+353
-0
lines changed

3 files changed

+353
-0
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import Union
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class MaskedFillModel(torch.nn.Module):
22+
def __init__(self, value: Union[float, int]):
23+
super().__init__()
24+
self.value = value
25+
26+
def forward(self, x, mask):
27+
return x.masked_fill(mask, self.value)
28+
29+
30+
@operator_test
31+
class MaskedFill(OperatorTest):
32+
@dtype_test
33+
def test_masked_fill_dtype(self, flow: TestFlow, dtype) -> None:
34+
mask = torch.randint(0, 2, (16, 32), dtype=torch.bool)
35+
self._test_op(
36+
MaskedFillModel(value=0.0),
37+
(
38+
torch.rand(16, 32).to(dtype),
39+
mask,
40+
),
41+
flow,
42+
)
43+
44+
def test_masked_fill_different_values(self, flow: TestFlow) -> None:
45+
mask = torch.randint(0, 2, (16, 32), dtype=torch.bool)
46+
47+
self._test_op(
48+
MaskedFillModel(value=5.0),
49+
(
50+
torch.randn(16, 32),
51+
mask,
52+
),
53+
flow,
54+
)
55+
56+
self._test_op(
57+
MaskedFillModel(value=-5.0),
58+
(
59+
torch.randn(16, 32),
60+
mask,
61+
),
62+
flow,
63+
)
64+
65+
self._test_op(
66+
MaskedFillModel(value=1),
67+
(
68+
torch.randn(16, 32),
69+
mask,
70+
),
71+
flow,
72+
)
73+
74+
def test_masked_fill_different_shapes(self, flow: TestFlow) -> None:
75+
self._test_op(
76+
MaskedFillModel(value=0.0),
77+
(
78+
torch.randn(512),
79+
torch.randint(0, 2, (512,), dtype=torch.bool),
80+
),
81+
flow,
82+
)
83+
84+
self._test_op(
85+
MaskedFillModel(value=0.0),
86+
(
87+
torch.randn(4, 8, 16),
88+
torch.randint(0, 2, (4, 8, 16), dtype=torch.bool),
89+
),
90+
flow,
91+
)
92+
93+
def test_masked_fill_broadcast(self, flow: TestFlow) -> None:
94+
self._test_op(
95+
MaskedFillModel(value=0.0),
96+
(
97+
torch.randn(16, 32),
98+
torch.randint(0, 2, (32,), dtype=torch.bool),
99+
),
100+
flow,
101+
)
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
from typing import List
10+
11+
import torch
12+
from executorch.backends.test.suite.flow import TestFlow
13+
14+
from executorch.backends.test.suite.operators import (
15+
dtype_test,
16+
operator_test,
17+
OperatorTest,
18+
)
19+
20+
21+
class PermuteModel(torch.nn.Module):
22+
def __init__(self, dims: List[int]):
23+
super().__init__()
24+
self.dims = dims
25+
26+
def forward(self, x):
27+
return x.permute(self.dims)
28+
29+
30+
@operator_test
31+
class Permute(OperatorTest):
32+
@dtype_test
33+
def test_permute_dtype(self, flow: TestFlow, dtype) -> None:
34+
self._test_op(
35+
PermuteModel(dims=[1, 0]),
36+
(torch.rand(20, 32).to(dtype),),
37+
flow,
38+
)
39+
40+
def test_permute_3d(self, flow: TestFlow) -> None:
41+
self._test_op(
42+
PermuteModel(dims=[2, 0, 1]),
43+
(torch.randn(8, 10, 12),),
44+
flow,
45+
)
46+
47+
self._test_op(
48+
PermuteModel(dims=[1, 2, 0]),
49+
(torch.randn(8, 10, 12),),
50+
flow,
51+
)
52+
53+
self._test_op(
54+
PermuteModel(dims=[0, 2, 1]),
55+
(torch.randn(8, 10, 12),),
56+
flow,
57+
)
58+
59+
def test_permute_4d(self, flow: TestFlow) -> None:
60+
self._test_op(
61+
PermuteModel(dims=[3, 2, 1, 0]),
62+
(torch.randn(4, 6, 8, 10),),
63+
flow,
64+
)
65+
66+
self._test_op(
67+
PermuteModel(dims=[0, 2, 1, 3]),
68+
(torch.randn(4, 6, 8, 10),),
69+
flow,
70+
)
71+
72+
def test_permute_identity(self, flow: TestFlow) -> None:
73+
self._test_op(
74+
PermuteModel(dims=[0, 1]),
75+
(torch.randn(20, 32),),
76+
flow,
77+
)
78+
79+
self._test_op(
80+
PermuteModel(dims=[0, 1, 2]),
81+
(torch.randn(8, 10, 12),),
82+
flow,
83+
)
84+
85+
def test_permute_negative_dims(self, flow: TestFlow) -> None:
86+
self._test_op(
87+
PermuteModel(dims=[-1, -3, -2, -4]),
88+
(torch.randn(4, 6, 8, 10),),
89+
flow,
90+
)
91+
92+
self._test_op(
93+
PermuteModel(dims=[-4, -2, -3, -1]),
94+
(torch.randn(4, 6, 8, 10),),
95+
flow,
96+
)
97+
98+
def test_permute_different_shapes(self, flow: TestFlow) -> None:
99+
self._test_op(
100+
PermuteModel(dims=[0]),
101+
(torch.randn(512),),
102+
flow,
103+
)
104+
105+
self._test_op(
106+
PermuteModel(dims=[4, 3, 2, 1, 0]),
107+
(torch.randn(2, 3, 4, 5, 6),),
108+
flow,
109+
)
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class TransposeModel(torch.nn.Module):
21+
def __init__(self, dim0: int, dim1: int):
22+
super().__init__()
23+
self.dim0 = dim0
24+
self.dim1 = dim1
25+
26+
def forward(self, x):
27+
return torch.transpose(x, self.dim0, self.dim1)
28+
29+
30+
@operator_test
31+
class Transpose(OperatorTest):
32+
@dtype_test
33+
def test_transpose_dtype(self, flow: TestFlow, dtype) -> None:
34+
self._test_op(
35+
TransposeModel(dim0=0, dim1=1),
36+
(torch.rand(20, 32).to(dtype),),
37+
flow,
38+
)
39+
40+
def test_transpose_basic(self, flow: TestFlow) -> None:
41+
self._test_op(
42+
TransposeModel(dim0=0, dim1=1),
43+
(torch.randn(20, 32),),
44+
flow,
45+
)
46+
47+
def test_transpose_3d(self, flow: TestFlow) -> None:
48+
self._test_op(
49+
TransposeModel(dim0=0, dim1=1),
50+
(torch.randn(8, 10, 12),),
51+
flow,
52+
)
53+
54+
self._test_op(
55+
TransposeModel(dim0=0, dim1=2),
56+
(torch.randn(8, 10, 12),),
57+
flow,
58+
)
59+
60+
self._test_op(
61+
TransposeModel(dim0=1, dim1=2),
62+
(torch.randn(8, 10, 12),),
63+
flow,
64+
)
65+
66+
def test_transpose_4d(self, flow: TestFlow) -> None:
67+
self._test_op(
68+
TransposeModel(dim0=0, dim1=3),
69+
(torch.randn(4, 6, 8, 10),),
70+
flow,
71+
)
72+
73+
self._test_op(
74+
TransposeModel(dim0=1, dim1=2),
75+
(torch.randn(4, 6, 8, 10),),
76+
flow,
77+
)
78+
79+
def test_transpose_identity(self, flow: TestFlow) -> None:
80+
self._test_op(
81+
TransposeModel(dim0=0, dim1=0),
82+
(torch.randn(20, 32),),
83+
flow,
84+
)
85+
self._test_op(
86+
TransposeModel(dim0=1, dim1=1),
87+
(torch.randn(20, 32),),
88+
flow,
89+
)
90+
91+
self._test_op(
92+
TransposeModel(dim0=0, dim1=0),
93+
(torch.randn(8, 10, 12),),
94+
flow,
95+
)
96+
self._test_op(
97+
TransposeModel(dim0=1, dim1=1),
98+
(torch.randn(8, 10, 12),),
99+
flow,
100+
)
101+
self._test_op(
102+
TransposeModel(dim0=2, dim1=2),
103+
(torch.randn(8, 10, 12),),
104+
flow,
105+
)
106+
107+
def test_transpose_negative_dims(self, flow: TestFlow) -> None:
108+
self._test_op(
109+
TransposeModel(dim0=-3, dim1=-1),
110+
(torch.randn(8, 10, 12),),
111+
flow,
112+
)
113+
114+
self._test_op(
115+
TransposeModel(dim0=-2, dim1=-1),
116+
(torch.randn(8, 10, 12),),
117+
flow,
118+
)
119+
120+
def test_transpose_different_shapes(self, flow: TestFlow) -> None:
121+
self._test_op(
122+
TransposeModel(dim0=0, dim1=1),
123+
(torch.randn(20, 32),),
124+
flow,
125+
)
126+
127+
self._test_op(
128+
TransposeModel(dim0=0, dim1=2),
129+
(torch.randn(8, 10, 12),),
130+
flow,
131+
)
132+
133+
self._test_op(
134+
TransposeModel(dim0=1, dim1=3),
135+
(torch.randn(4, 6, 8, 10),),
136+
flow,
137+
)
138+
139+
self._test_op(
140+
TransposeModel(dim0=0, dim1=4),
141+
(torch.randn(2, 3, 4, 5, 6),),
142+
flow,
143+
)

0 commit comments

Comments
 (0)