Skip to content

Commit 2adbf61

Browse files
committed
Update
[ghstack-poisoned]
1 parent aacddf5 commit 2adbf61

File tree

2 files changed

+219
-0
lines changed

2 files changed

+219
-0
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable, List, Tuple
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class IndexPutModel(torch.nn.Module):
16+
def __init__(self, accumulate=False):
17+
super().__init__()
18+
self.accumulate = accumulate
19+
20+
def forward(self, x, indices, values):
21+
# Clone the input to avoid modifying it in-place
22+
result = x.clone()
23+
# Apply index_put_ and return the modified tensor
24+
result.index_put_(indices, values, self.accumulate)
25+
return result
26+
27+
@operator_test
28+
class TestIndexPut(OperatorTest):
29+
@dtype_test
30+
def test_index_put_dtype(self, dtype, tester_factory: Callable) -> None:
31+
# Test with different dtypes
32+
indices = (torch.tensor([0, 2]),)
33+
values = torch.tensor([10.0, 20.0]).to(dtype)
34+
model = IndexPutModel()
35+
self._test_op(model, ((torch.rand(5, 2) * 100).to(dtype), indices, values), tester_factory, use_random_test_inputs=False)
36+
37+
def test_index_put_basic(self, tester_factory: Callable) -> None:
38+
# Basic test with default parameters
39+
indices = (torch.tensor([0, 2]),)
40+
values = torch.tensor([10.0, 20.0])
41+
self._test_op(IndexPutModel(), (torch.randn(5, 2), indices, values), tester_factory, use_random_test_inputs=False)
42+
43+
def test_index_put_accumulate(self, tester_factory: Callable) -> None:
44+
# Test with accumulate=True and accumulate=False
45+
46+
# Without accumulation (replace values)
47+
indices = (torch.tensor([0, 2]),)
48+
values = torch.tensor([10.0, 20.0])
49+
self._test_op(IndexPutModel(accumulate=False),
50+
(torch.ones(5, 2), indices, values), tester_factory, use_random_test_inputs=False)
51+
52+
# With accumulation (add values)
53+
indices = (torch.tensor([0, 2]),)
54+
values = torch.tensor([10.0, 20.0])
55+
self._test_op(IndexPutModel(accumulate=True),
56+
(torch.ones(5, 2), indices, values), tester_factory, use_random_test_inputs=False)
57+
58+
def test_index_put_shapes(self, tester_factory: Callable) -> None:
59+
# Test with different tensor shapes
60+
61+
# 1D tensor
62+
indices = (torch.tensor([0, 2]),)
63+
values = torch.tensor([10.0, 20.0])
64+
self._test_op(IndexPutModel(),
65+
(torch.randn(5), indices, values), tester_factory, use_random_test_inputs=False)
66+
67+
# 2D tensor
68+
indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
69+
values = torch.tensor([10.0, 20.0])
70+
self._test_op(IndexPutModel(),
71+
(torch.randn(5, 2), indices, values), tester_factory, use_random_test_inputs=False)
72+
73+
# 3D tensor
74+
indices = (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1]))
75+
values = torch.tensor([10.0, 20.0])
76+
self._test_op(IndexPutModel(),
77+
(torch.randn(5, 3, 2), indices, values), tester_factory, use_random_test_inputs=False)
78+
79+
# 4D tensor
80+
indices = (torch.tensor([0, 2]), torch.tensor([1, 1]),
81+
torch.tensor([0, 1]), torch.tensor([2, 3]))
82+
values = torch.tensor([10.0,])
83+
self._test_op(IndexPutModel(),
84+
(torch.randn(5, 3, 2, 4), indices, values), tester_factory, use_random_test_inputs=False)
85+
86+
def test_index_put_indices(self, tester_factory: Callable) -> None:
87+
# Test with different index patterns
88+
89+
# Single index
90+
indices = (torch.tensor([2]),)
91+
values = torch.tensor([10.0])
92+
self._test_op(IndexPutModel(),
93+
(torch.randn(5, 2), indices, values), tester_factory, use_random_test_inputs=False)
94+
95+
# Multiple indices
96+
indices = (torch.tensor([0, 2, 4]),)
97+
values = torch.tensor([10.0, 20.0, 30.0])
98+
self._test_op(IndexPutModel(),
99+
(torch.randn(5, 3), indices, values), tester_factory, use_random_test_inputs=False)
100+
101+
# Repeated indices with accumulate=True (values add up)
102+
indices = (torch.tensor([1, 1, 3, 3]),)
103+
values = torch.tensor([10.0, 20.0, 30.0, 40.0])
104+
self._test_op(IndexPutModel(accumulate=True),
105+
(torch.randn(5), indices, values), tester_factory, use_random_test_inputs=False)
106+
107+
def test_index_put_edge_cases(self, tester_factory: Callable) -> None:
108+
# Test edge cases
109+
110+
# Put values in all positions
111+
indices = (torch.tensor([0, 1, 2, 3, 4]),)
112+
values = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0])
113+
self._test_op(IndexPutModel(),
114+
(torch.randn(5, 5), indices, values), tester_factory, use_random_test_inputs=False)
115+
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2+
3+
# pyre-strict
4+
5+
from typing import Callable
6+
7+
import torch
8+
9+
from executorch.backends.test.compliance_suite import (
10+
dtype_test,
11+
operator_test,
12+
OperatorTest,
13+
)
14+
15+
class IndexSelectModel(torch.nn.Module):
16+
def __init__(self, dim=0):
17+
super().__init__()
18+
self.dim = dim
19+
20+
def forward(self, x, indices):
21+
return torch.index_select(x, self.dim, indices)
22+
23+
@operator_test
24+
class TestIndexSelect(OperatorTest):
25+
@dtype_test
26+
def test_index_select_dtype(self, dtype, tester_factory: Callable) -> None:
27+
# Test with different dtypes
28+
indices = torch.tensor([0, 2], dtype=torch.int64)
29+
model = IndexSelectModel(dim=0)
30+
self._test_op(model, ((torch.rand(5, 3) * 100).to(dtype), indices), tester_factory, use_random_test_inputs=False)
31+
32+
def test_index_select_basic(self, tester_factory: Callable) -> None:
33+
# Basic test with default parameters
34+
indices = torch.tensor([0, 2], dtype=torch.int64)
35+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3), indices), tester_factory, use_random_test_inputs=False)
36+
37+
def test_index_select_dimensions(self, tester_factory: Callable) -> None:
38+
# Test selecting along different dimensions
39+
40+
# Select along dim 0
41+
indices = torch.tensor([0, 2], dtype=torch.int64)
42+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3), indices), tester_factory, use_random_test_inputs=False)
43+
44+
# Select along dim 1
45+
indices = torch.tensor([0, 1], dtype=torch.int64)
46+
self._test_op(IndexSelectModel(dim=1), (torch.randn(5, 3), indices), tester_factory, use_random_test_inputs=False)
47+
48+
# Select along dim 2 in a 3D tensor
49+
indices = torch.tensor([0, 2], dtype=torch.int64)
50+
self._test_op(IndexSelectModel(dim=2), (torch.randn(3, 4, 5), indices), tester_factory, use_random_test_inputs=False)
51+
52+
def test_index_select_shapes(self, tester_factory: Callable) -> None:
53+
# Test with different tensor shapes
54+
indices = torch.tensor([0, 1], dtype=torch.int64)
55+
56+
# 1D tensor
57+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5), indices), tester_factory, use_random_test_inputs=False)
58+
59+
# 2D tensor
60+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3), indices), tester_factory, use_random_test_inputs=False)
61+
62+
# 3D tensor
63+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3, 2), indices), tester_factory, use_random_test_inputs=False)
64+
65+
# 4D tensor
66+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3, 2, 4), indices), tester_factory, use_random_test_inputs=False)
67+
68+
def test_index_select_indices(self, tester_factory: Callable) -> None:
69+
# Test with different index patterns
70+
71+
# Single index
72+
indices = torch.tensor([2], dtype=torch.int64)
73+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3), indices), tester_factory, use_random_test_inputs=False)
74+
75+
# Multiple indices
76+
indices = torch.tensor([0, 2, 4], dtype=torch.int64)
77+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3), indices), tester_factory, use_random_test_inputs=False)
78+
79+
# Repeated indices
80+
indices = torch.tensor([1, 1, 3, 3], dtype=torch.int64)
81+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3), indices), tester_factory, use_random_test_inputs=False)
82+
83+
# Reversed indices
84+
indices = torch.tensor([4, 3, 2, 1, 0], dtype=torch.int64)
85+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3), indices), tester_factory, use_random_test_inputs=False)
86+
87+
def test_index_select_edge_cases(self, tester_factory: Callable) -> None:
88+
# Test edge cases
89+
90+
# Select all indices
91+
indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
92+
self._test_op(IndexSelectModel(dim=0), (torch.randn(5, 3), indices), tester_factory, use_random_test_inputs=False)
93+
94+
# Select from a dimension with size 1
95+
indices = torch.tensor([0], dtype=torch.int64)
96+
self._test_op(IndexSelectModel(dim=0), (torch.randn(1, 3), indices), tester_factory, use_random_test_inputs=False)
97+
98+
# Select from a tensor with all zeros
99+
indices = torch.tensor([0, 1], dtype=torch.int64)
100+
self._test_op(IndexSelectModel(dim=0), (torch.zeros(5, 3), indices), tester_factory, use_random_test_inputs=False)
101+
102+
# Select from a tensor with all ones
103+
indices = torch.tensor([0, 1], dtype=torch.int64)
104+
self._test_op(IndexSelectModel(dim=0), (torch.ones(5, 3), indices), tester_factory, use_random_test_inputs=False)

0 commit comments

Comments
 (0)