Skip to content

Commit 29721df

Browse files
committed
[Backend Tester] Add index_put and index_select tests
ghstack-source-id: 9a2c93c ghstack-comment-id: 3116316749 Pull-Request: #12852
1 parent 04ee2df commit 29721df

File tree

2 files changed

+327
-0
lines changed

2 files changed

+327
-0
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class IndexPutModel(torch.nn.Module):
21+
def __init__(self, accumulate=False):
22+
super().__init__()
23+
self.accumulate = accumulate
24+
25+
def forward(self, x, indices, values):
26+
# Clone the input to avoid modifying it in-place
27+
result = x.clone()
28+
# Apply index_put_ and return the modified tensor
29+
result.index_put_(indices, values, self.accumulate)
30+
return result
31+
32+
33+
@operator_test
34+
class IndexPut(OperatorTest):
35+
@dtype_test
36+
def test_index_put_dtype(self, flow: TestFlow, dtype) -> None:
37+
indices = (torch.tensor([0, 2]),)
38+
values = torch.tensor([10.0, 20.0]).to(dtype)
39+
self._test_op(
40+
IndexPutModel(),
41+
((torch.rand(5, 2) * 100).to(dtype), indices, values),
42+
flow,
43+
generate_random_test_inputs=False,
44+
)
45+
46+
def test_index_put_basic(self, flow: TestFlow) -> None:
47+
indices = (torch.tensor([0, 2]),)
48+
values = torch.tensor([10.0, 20.0])
49+
self._test_op(
50+
IndexPutModel(),
51+
(torch.randn(5, 2), indices, values),
52+
flow,
53+
generate_random_test_inputs=False,
54+
)
55+
56+
def test_index_put_accumulate(self, flow: TestFlow) -> None:
57+
indices = (torch.tensor([0, 2]),)
58+
values = torch.tensor([10.0, 20.0])
59+
self._test_op(
60+
IndexPutModel(accumulate=False),
61+
(torch.ones(5, 2), indices, values),
62+
flow,
63+
generate_random_test_inputs=False,
64+
)
65+
66+
indices = (torch.tensor([0, 2]),)
67+
values = torch.tensor([10.0, 20.0])
68+
self._test_op(
69+
IndexPutModel(accumulate=True),
70+
(torch.ones(5, 2), indices, values),
71+
flow,
72+
generate_random_test_inputs=False,
73+
)
74+
75+
def test_index_put_shapes(self, flow: TestFlow) -> None:
76+
indices = (torch.tensor([0, 2]),)
77+
values = torch.tensor([10.0, 20.0])
78+
self._test_op(
79+
IndexPutModel(),
80+
(torch.randn(5), indices, values),
81+
flow,
82+
generate_random_test_inputs=False,
83+
)
84+
85+
indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
86+
values = torch.tensor([10.0, 20.0])
87+
self._test_op(
88+
IndexPutModel(),
89+
(torch.randn(5, 2), indices, values),
90+
flow,
91+
generate_random_test_inputs=False,
92+
)
93+
94+
indices = (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1]))
95+
values = torch.tensor([10.0, 20.0])
96+
self._test_op(
97+
IndexPutModel(),
98+
(torch.randn(5, 3, 2), indices, values),
99+
flow,
100+
generate_random_test_inputs=False,
101+
)
102+
103+
indices = (
104+
torch.tensor([0, 2]),
105+
torch.tensor([1, 1]),
106+
torch.tensor([0, 1]),
107+
torch.tensor([2, 3]),
108+
)
109+
values = torch.tensor(
110+
[
111+
10.0,
112+
]
113+
)
114+
self._test_op(
115+
IndexPutModel(),
116+
(torch.randn(5, 3, 2, 4), indices, values),
117+
flow,
118+
generate_random_test_inputs=False,
119+
)
120+
121+
def test_index_put_indices(self, flow: TestFlow) -> None:
122+
indices = (torch.tensor([2]),)
123+
values = torch.tensor([10.0])
124+
self._test_op(
125+
IndexPutModel(),
126+
(torch.randn(5, 2), indices, values),
127+
flow,
128+
generate_random_test_inputs=False,
129+
)
130+
131+
indices = (torch.tensor([0, 2, 4]),)
132+
values = torch.tensor([10.0, 20.0, 30.0])
133+
self._test_op(
134+
IndexPutModel(),
135+
(torch.randn(5, 3), indices, values),
136+
flow,
137+
generate_random_test_inputs=False,
138+
)
139+
140+
indices = (torch.tensor([1, 1, 3, 3]),)
141+
values = torch.tensor([10.0, 20.0, 30.0, 40.0])
142+
self._test_op(
143+
IndexPutModel(accumulate=True),
144+
(torch.randn(5), indices, values),
145+
flow,
146+
generate_random_test_inputs=False,
147+
)
148+
149+
def test_index_put_edge_cases(self, flow: TestFlow) -> None:
150+
indices = (torch.tensor([0, 1, 2, 3, 4]),)
151+
values = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0])
152+
self._test_op(
153+
IndexPutModel(),
154+
(torch.randn(5, 5), indices, values),
155+
flow,
156+
generate_random_test_inputs=False,
157+
)
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class IndexSelectModel(torch.nn.Module):
21+
def __init__(self, dim=0):
22+
super().__init__()
23+
self.dim = dim
24+
25+
def forward(self, x, indices):
26+
return torch.index_select(x, self.dim, indices)
27+
28+
29+
@operator_test
30+
class IndexSelect(OperatorTest):
31+
@dtype_test
32+
def test_index_select_dtype(self, flow: TestFlow, dtype) -> None:
33+
indices = torch.tensor([0, 2], dtype=torch.int64)
34+
self._test_op(
35+
IndexSelectModel(dim=0),
36+
((torch.rand(5, 3) * 100).to(dtype), indices),
37+
flow,
38+
generate_random_test_inputs=False,
39+
)
40+
41+
def test_index_select_basic(self, flow: TestFlow) -> None:
42+
indices = torch.tensor([0, 2], dtype=torch.int64)
43+
self._test_op(
44+
IndexSelectModel(dim=0),
45+
(torch.randn(5, 3), indices),
46+
flow,
47+
generate_random_test_inputs=False,
48+
)
49+
50+
def test_index_select_dimensions(self, flow: TestFlow) -> None:
51+
indices = torch.tensor([0, 2], dtype=torch.int64)
52+
self._test_op(
53+
IndexSelectModel(dim=0),
54+
(torch.randn(5, 3), indices),
55+
flow,
56+
generate_random_test_inputs=False,
57+
)
58+
59+
indices = torch.tensor([0, 1], dtype=torch.int64)
60+
self._test_op(
61+
IndexSelectModel(dim=1),
62+
(torch.randn(5, 3), indices),
63+
flow,
64+
generate_random_test_inputs=False,
65+
)
66+
67+
indices = torch.tensor([0, 2], dtype=torch.int64)
68+
self._test_op(
69+
IndexSelectModel(dim=2),
70+
(torch.randn(3, 4, 5), indices),
71+
flow,
72+
generate_random_test_inputs=False,
73+
)
74+
75+
def test_index_select_shapes(self, flow: TestFlow) -> None:
76+
indices = torch.tensor([0, 1], dtype=torch.int64)
77+
78+
self._test_op(
79+
IndexSelectModel(dim=0),
80+
(torch.randn(5), indices),
81+
flow,
82+
generate_random_test_inputs=False,
83+
)
84+
85+
self._test_op(
86+
IndexSelectModel(dim=0),
87+
(torch.randn(5, 3), indices),
88+
flow,
89+
generate_random_test_inputs=False,
90+
)
91+
92+
self._test_op(
93+
IndexSelectModel(dim=0),
94+
(torch.randn(5, 3, 2), indices),
95+
flow,
96+
generate_random_test_inputs=False,
97+
)
98+
99+
self._test_op(
100+
IndexSelectModel(dim=0),
101+
(torch.randn(5, 3, 2, 4), indices),
102+
flow,
103+
generate_random_test_inputs=False,
104+
)
105+
106+
def test_index_select_indices(self, flow: TestFlow) -> None:
107+
indices = torch.tensor([2], dtype=torch.int64)
108+
self._test_op(
109+
IndexSelectModel(dim=0),
110+
(torch.randn(5, 3), indices),
111+
flow,
112+
generate_random_test_inputs=False,
113+
)
114+
115+
indices = torch.tensor([0, 2, 4], dtype=torch.int64)
116+
self._test_op(
117+
IndexSelectModel(dim=0),
118+
(torch.randn(5, 3), indices),
119+
flow,
120+
generate_random_test_inputs=False,
121+
)
122+
123+
indices = torch.tensor([1, 1, 3, 3], dtype=torch.int64)
124+
self._test_op(
125+
IndexSelectModel(dim=0),
126+
(torch.randn(5, 3), indices),
127+
flow,
128+
generate_random_test_inputs=False,
129+
)
130+
131+
indices = torch.tensor([4, 3, 2, 1, 0], dtype=torch.int64)
132+
self._test_op(
133+
IndexSelectModel(dim=0),
134+
(torch.randn(5, 3), indices),
135+
flow,
136+
generate_random_test_inputs=False,
137+
)
138+
139+
def test_index_select_edge_cases(self, flow: TestFlow) -> None:
140+
indices = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64)
141+
self._test_op(
142+
IndexSelectModel(dim=0),
143+
(torch.randn(5, 3), indices),
144+
flow,
145+
generate_random_test_inputs=False,
146+
)
147+
148+
indices = torch.tensor([0], dtype=torch.int64)
149+
self._test_op(
150+
IndexSelectModel(dim=0),
151+
(torch.randn(1, 3), indices),
152+
flow,
153+
generate_random_test_inputs=False,
154+
)
155+
156+
indices = torch.tensor([0, 1], dtype=torch.int64)
157+
self._test_op(
158+
IndexSelectModel(dim=0),
159+
(torch.zeros(5, 3), indices),
160+
flow,
161+
generate_random_test_inputs=False,
162+
)
163+
164+
indices = torch.tensor([0, 1], dtype=torch.int64)
165+
self._test_op(
166+
IndexSelectModel(dim=0),
167+
(torch.ones(5, 3), indices),
168+
flow,
169+
generate_random_test_inputs=False,
170+
)

0 commit comments

Comments
 (0)