Skip to content

Commit ca1b887

Browse files
committed
Update
[ghstack-poisoned]
1 parent 8b11366 commit ca1b887

File tree

3 files changed

+314
-0
lines changed

3 files changed

+314
-0
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, List, Union
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class MaskedFillModel(torch.nn.Module):
16+
def __init__(self, value: Union[float, int]):
17+
super().__init__()
18+
self.value = value
19+
20+
def forward(self, x, mask):
21+
return x.masked_fill(mask, self.value)
22+
23+
@operator_test
24+
class TestMaskedFill(OperatorTest):
25+
@dtype_test
26+
def test_masked_fill_dtype(self, dtype, tester_factory: Callable) -> None:
27+
# Test with different dtypes
28+
model = MaskedFillModel(value=0.0)
29+
self._test_op(
30+
model,
31+
(
32+
torch.rand(3, 4).to(dtype),
33+
torch.tensor([[True, False, True, False], [False, True, False, True], [True, True, False, False]]),
34+
),
35+
tester_factory
36+
)
37+
38+
def test_masked_fill_basic(self, tester_factory: Callable) -> None:
39+
# Basic test with default parameters
40+
# Fill with 0.0 where mask is True
41+
self._test_op(
42+
MaskedFillModel(value=0.0),
43+
(
44+
torch.randn(3, 4),
45+
torch.tensor([[True, False, True, False], [False, True, False, True], [True, True, False, False]]),
46+
),
47+
tester_factory
48+
)
49+
50+
def test_masked_fill_different_values(self, tester_factory: Callable) -> None:
51+
# Test with different fill values
52+
53+
# Fill with a positive value
54+
self._test_op(
55+
MaskedFillModel(value=5.0),
56+
(
57+
torch.randn(3, 4),
58+
torch.tensor([[True, False, True, False], [False, True, False, True], [True, True, False, False]]),
59+
),
60+
tester_factory
61+
)
62+
63+
# Fill with a negative value
64+
self._test_op(
65+
MaskedFillModel(value=-5.0),
66+
(
67+
torch.randn(3, 4),
68+
torch.tensor([[True, False, True, False], [False, True, False, True], [True, True, False, False]]),
69+
),
70+
tester_factory
71+
)
72+
73+
# Fill with an integer value
74+
self._test_op(
75+
MaskedFillModel(value=1),
76+
(
77+
torch.randn(3, 4),
78+
torch.tensor([[True, False, True, False], [False, True, False, True], [True, True, False, False]]),
79+
),
80+
tester_factory
81+
)
82+
83+
def test_masked_fill_different_shapes(self, tester_factory: Callable) -> None:
84+
# Test with tensors of different shapes
85+
86+
# 1D tensor
87+
self._test_op(
88+
MaskedFillModel(value=0.0),
89+
(
90+
torch.randn(5),
91+
torch.tensor([True, False, True, False, True]),
92+
),
93+
tester_factory
94+
)
95+
96+
# 3D tensor
97+
self._test_op(
98+
MaskedFillModel(value=0.0),
99+
(
100+
torch.randn(2, 3, 4),
101+
torch.tensor([
102+
[[True, False, True, False], [False, True, False, True], [True, True, False, False]],
103+
[[False, False, True, True], [True, False, True, False], [False, True, False, True]]
104+
]),
105+
),
106+
tester_factory
107+
)
108+
109+
def test_masked_fill_all_true(self, tester_factory: Callable) -> None:
110+
# Test with all mask values set to True
111+
self._test_op(
112+
MaskedFillModel(value=0.0),
113+
(
114+
torch.randn(3, 4),
115+
torch.ones(3, 4, dtype=torch.bool),
116+
),
117+
tester_factory
118+
)
119+
120+
def test_masked_fill_all_false(self, tester_factory: Callable) -> None:
121+
# Test with all mask values set to False
122+
self._test_op(
123+
MaskedFillModel(value=0.0),
124+
(
125+
torch.randn(3, 4),
126+
torch.zeros(3, 4, dtype=torch.bool),
127+
),
128+
tester_factory
129+
)
130+
131+
def test_masked_fill_broadcast(self, tester_factory: Callable) -> None:
132+
# Test with broadcasting mask
133+
# A 1D mask can be broadcast to a 2D tensor
134+
self._test_op(
135+
MaskedFillModel(value=0.0),
136+
(
137+
torch.randn(3, 4),
138+
torch.tensor([True, False, True, False]),
139+
),
140+
tester_factory
141+
)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, List
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class PermuteModel(torch.nn.Module):
16+
def __init__(self, dims: List[int]):
17+
super().__init__()
18+
self.dims = dims
19+
20+
def forward(self, x):
21+
return x.permute(self.dims)
22+
23+
@operator_test
24+
class TestPermute(OperatorTest):
25+
@dtype_test
26+
def test_permute_dtype(self, dtype, tester_factory: Callable) -> None:
27+
# Test with different dtypes
28+
model = PermuteModel(dims=[1, 0])
29+
self._test_op(model, (torch.rand(3, 4).to(dtype),), tester_factory)
30+
31+
def test_permute_basic(self, tester_factory: Callable) -> None:
32+
# Basic test with default parameters
33+
# Permute a 2D tensor from [3, 4] to [4, 3]
34+
self._test_op(PermuteModel(dims=[1, 0]), (torch.randn(3, 4),), tester_factory)
35+
36+
def test_permute_3d(self, tester_factory: Callable) -> None:
37+
# Test permuting a 3D tensor
38+
39+
# Permute from [2, 3, 4] to [4, 2, 3]
40+
self._test_op(PermuteModel(dims=[2, 0, 1]), (torch.randn(2, 3, 4),), tester_factory)
41+
42+
# Permute from [2, 3, 4] to [3, 4, 2]
43+
self._test_op(PermuteModel(dims=[1, 2, 0]), (torch.randn(2, 3, 4),), tester_factory)
44+
45+
# Permute from [2, 3, 4] to [2, 4, 3]
46+
self._test_op(PermuteModel(dims=[0, 2, 1]), (torch.randn(2, 3, 4),), tester_factory)
47+
48+
def test_permute_4d(self, tester_factory: Callable) -> None:
49+
# Test permuting a 4D tensor
50+
51+
# Permute from [2, 3, 4, 5] to [5, 4, 3, 2]
52+
self._test_op(PermuteModel(dims=[3, 2, 1, 0]), (torch.randn(2, 3, 4, 5),), tester_factory)
53+
54+
# Permute from [2, 3, 4, 5] to [2, 4, 3, 5]
55+
self._test_op(PermuteModel(dims=[0, 2, 1, 3]), (torch.randn(2, 3, 4, 5),), tester_factory)
56+
57+
def test_permute_identity(self, tester_factory: Callable) -> None:
58+
# Test identity permutation (no change)
59+
60+
# 2D tensor
61+
self._test_op(PermuteModel(dims=[0, 1]), (torch.randn(3, 4),), tester_factory)
62+
63+
# 3D tensor
64+
self._test_op(PermuteModel(dims=[0, 1, 2]), (torch.randn(2, 3, 4),), tester_factory)
65+
66+
def test_permute_different_shapes(self, tester_factory: Callable) -> None:
67+
# Test with tensors of different shapes
68+
69+
# 1D tensor (no permutation possible)
70+
self._test_op(PermuteModel(dims=[0]), (torch.randn(5),), tester_factory)
71+
72+
# 5D tensor
73+
self._test_op(PermuteModel(dims=[4, 3, 2, 1, 0]), (torch.randn(2, 3, 4, 5, 6),), tester_factory)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, List
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class TransposeModel(torch.nn.Module):
16+
def __init__(self, dim0: int, dim1: int):
17+
super().__init__()
18+
self.dim0 = dim0
19+
self.dim1 = dim1
20+
21+
def forward(self, x):
22+
return torch.transpose(x, self.dim0, self.dim1)
23+
24+
@operator_test
25+
class TestTranspose(OperatorTest):
26+
@dtype_test
27+
def test_transpose_dtype(self, dtype, tester_factory: Callable) -> None:
28+
# Test with different dtypes
29+
model = TransposeModel(dim0=0, dim1=1)
30+
self._test_op(model, (torch.rand(3, 4).to(dtype),), tester_factory)
31+
32+
def test_transpose_basic(self, tester_factory: Callable) -> None:
33+
# Basic test with default parameters
34+
# Transpose a 2D tensor from [3, 4] to [4, 3]
35+
self._test_op(TransposeModel(dim0=0, dim1=1), (torch.randn(3, 4),), tester_factory)
36+
37+
def test_transpose_3d(self, tester_factory: Callable) -> None:
38+
# Test transposing a 3D tensor
39+
40+
# Transpose dimensions 0 and 1
41+
# From [2, 3, 4] to [3, 2, 4]
42+
self._test_op(TransposeModel(dim0=0, dim1=1), (torch.randn(2, 3, 4),), tester_factory)
43+
44+
# Transpose dimensions 0 and 2
45+
# From [2, 3, 4] to [4, 3, 2]
46+
self._test_op(TransposeModel(dim0=0, dim1=2), (torch.randn(2, 3, 4),), tester_factory)
47+
48+
# Transpose dimensions 1 and 2
49+
# From [2, 3, 4] to [2, 4, 3]
50+
self._test_op(TransposeModel(dim0=1, dim1=2), (torch.randn(2, 3, 4),), tester_factory)
51+
52+
def test_transpose_4d(self, tester_factory: Callable) -> None:
53+
# Test transposing a 4D tensor
54+
55+
# Transpose dimensions 0 and 3
56+
# From [2, 3, 4, 5] to [5, 3, 4, 2]
57+
self._test_op(TransposeModel(dim0=0, dim1=3), (torch.randn(2, 3, 4, 5),), tester_factory)
58+
59+
# Transpose dimensions 1 and 2
60+
# From [2, 3, 4, 5] to [2, 4, 3, 5]
61+
self._test_op(TransposeModel(dim0=1, dim1=2), (torch.randn(2, 3, 4, 5),), tester_factory)
62+
63+
def test_transpose_identity(self, tester_factory: Callable) -> None:
64+
# Test identity transpose (same dimension, no change)
65+
66+
# 2D tensor
67+
self._test_op(TransposeModel(dim0=0, dim1=0), (torch.randn(3, 4),), tester_factory)
68+
self._test_op(TransposeModel(dim0=1, dim1=1), (torch.randn(3, 4),), tester_factory)
69+
70+
# 3D tensor
71+
self._test_op(TransposeModel(dim0=0, dim1=0), (torch.randn(2, 3, 4),), tester_factory)
72+
self._test_op(TransposeModel(dim0=1, dim1=1), (torch.randn(2, 3, 4),), tester_factory)
73+
self._test_op(TransposeModel(dim0=2, dim1=2), (torch.randn(2, 3, 4),), tester_factory)
74+
75+
def test_transpose_negative_dims(self, tester_factory: Callable) -> None:
76+
# Test with negative dimensions (counting from the end)
77+
78+
# 3D tensor
79+
# Transpose dimensions -3 and -1 (equivalent to 0 and 2)
80+
# From [2, 3, 4] to [4, 3, 2]
81+
self._test_op(TransposeModel(dim0=-3, dim1=-1), (torch.randn(2, 3, 4),), tester_factory)
82+
83+
# Transpose dimensions -2 and -1 (equivalent to 1 and 2)
84+
# From [2, 3, 4] to [2, 4, 3]
85+
self._test_op(TransposeModel(dim0=-2, dim1=-1), (torch.randn(2, 3, 4),), tester_factory)
86+
87+
def test_transpose_different_shapes(self, tester_factory: Callable) -> None:
88+
# Test with tensors of different shapes
89+
90+
# 2D tensor
91+
self._test_op(TransposeModel(dim0=0, dim1=1), (torch.randn(3, 4),), tester_factory)
92+
93+
# 3D tensor
94+
self._test_op(TransposeModel(dim0=0, dim1=2), (torch.randn(2, 3, 4),), tester_factory)
95+
96+
# 4D tensor
97+
self._test_op(TransposeModel(dim0=1, dim1=3), (torch.randn(2, 3, 4, 5),), tester_factory)
98+
99+
# 5D tensor
100+
self._test_op(TransposeModel(dim0=0, dim1=4), (torch.randn(2, 3, 4, 5, 6),), tester_factory)

0 commit comments

Comments
 (0)