Skip to content

Commit 208d7df

Browse files
committed
[Backend Tester] Add index_put and index_select tests
ghstack-source-id: 4f4546d ghstack-comment-id: 3116316749 Pull-Request: #12852
1 parent 2ec4f59 commit 208d7df

File tree

3 files changed

+371
-1
lines changed

3 files changed

+371
-1
lines changed

backends/test/harness/tester.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,10 @@ def run_method_and_compare_outputs(
311311
print(f"Comparing Stage {stage} with Stage {reference_stage}")
312312
for run_iteration in range(number_of_runs):
313313
inputs_to_run = inputs if inputs else next(self.generate_random_inputs())
314-
input_shapes = [generated_input.shape for generated_input in inputs_to_run]
314+
input_shapes = [
315+
generated_input.shape if hasattr(generated_input, "shape") else None
316+
for generated_input in inputs_to_run
317+
]
315318
print(f"Run {run_iteration} with input shapes: {input_shapes}")
316319

317320
# Reference output (and quantization scale)
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class IndexPutModel(torch.nn.Module):
21+
def __init__(self, accumulate=False):
22+
super().__init__()
23+
self.accumulate = accumulate
24+
25+
def forward(self, x, indices, values):
26+
# Clone the input to avoid modifying it in-place
27+
result = x.clone()
28+
# Apply index_put_ and return the modified tensor
29+
result.index_put_(indices, values, self.accumulate)
30+
return result
31+
32+
33+
@operator_test
34+
class IndexPut(OperatorTest):
35+
@dtype_test
36+
def test_index_put_dtype(self, flow: TestFlow, dtype) -> None:
37+
indices = (torch.tensor([0, 2]),)
38+
values = torch.tensor([10.0, 20.0]).to(dtype)
39+
self._test_op(
40+
IndexPutModel(),
41+
((torch.rand(5, 2) * 100).to(dtype), indices, values),
42+
flow,
43+
generate_random_test_inputs=False,
44+
)
45+
46+
def test_index_put_accumulate(self, flow: TestFlow) -> None:
47+
indices = (torch.tensor([0, 2]),)
48+
values = torch.tensor([10.0, 20.0])
49+
self._test_op(
50+
IndexPutModel(accumulate=False),
51+
(torch.ones(5, 2), indices, values),
52+
flow,
53+
generate_random_test_inputs=False,
54+
)
55+
56+
indices = (torch.tensor([0, 2]),)
57+
values = torch.tensor([10.0, 20.0])
58+
self._test_op(
59+
IndexPutModel(accumulate=True),
60+
(torch.ones(5, 2), indices, values),
61+
flow,
62+
generate_random_test_inputs=False,
63+
)
64+
65+
def test_index_put_shapes(self, flow: TestFlow) -> None:
66+
indices = (torch.tensor([0, 2]),)
67+
values = torch.tensor([10.0, 20.0])
68+
self._test_op(
69+
IndexPutModel(),
70+
(torch.randn(5), indices, values),
71+
flow,
72+
generate_random_test_inputs=False,
73+
)
74+
75+
indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
76+
values = torch.tensor([10.0, 20.0])
77+
self._test_op(
78+
IndexPutModel(),
79+
(torch.randn(5, 2), indices, values),
80+
flow,
81+
generate_random_test_inputs=False,
82+
)
83+
84+
indices = (torch.tensor([0, 2]), torch.tensor([1, 1]), torch.tensor([0, 1]))
85+
values = torch.tensor([10.0, 20.0])
86+
self._test_op(
87+
IndexPutModel(),
88+
(torch.randn(5, 3, 2), indices, values),
89+
flow,
90+
generate_random_test_inputs=False,
91+
)
92+
93+
indices = (
94+
torch.tensor([0, 2]),
95+
torch.tensor([1, 1]),
96+
torch.tensor([0, 1]),
97+
torch.tensor([2, 3]),
98+
)
99+
values = torch.tensor(
100+
[
101+
10.0,
102+
]
103+
)
104+
self._test_op(
105+
IndexPutModel(),
106+
(torch.randn(5, 3, 2, 4), indices, values),
107+
flow,
108+
generate_random_test_inputs=False,
109+
)
110+
111+
def test_index_put_indices(self, flow: TestFlow) -> None:
112+
indices = (torch.tensor([2]),)
113+
values = torch.tensor([10.0])
114+
self._test_op(
115+
IndexPutModel(),
116+
(torch.randn(5, 2), indices, values),
117+
flow,
118+
generate_random_test_inputs=False,
119+
)
120+
121+
indices = (torch.tensor([0, 2, 4]),)
122+
values = torch.tensor([10.0, 20.0, 30.0])
123+
self._test_op(
124+
IndexPutModel(),
125+
(torch.randn(5, 3), indices, values),
126+
flow,
127+
generate_random_test_inputs=False,
128+
)
129+
130+
indices = (torch.tensor([1, 1, 3, 3]),)
131+
values = torch.tensor([10.0, 20.0, 30.0, 40.0])
132+
self._test_op(
133+
IndexPutModel(accumulate=True),
134+
(torch.randn(5), indices, values),
135+
flow,
136+
generate_random_test_inputs=False,
137+
)
138+
139+
def test_index_put_broadcasting(self, flow: TestFlow) -> None:
140+
# Test scalar broadcasting - single value to multiple positions
141+
indices = (torch.tensor([0, 2, 4]),)
142+
values = torch.tensor([42.0])
143+
self._test_op(
144+
IndexPutModel(),
145+
(torch.randn(5, 3), indices, values),
146+
flow,
147+
generate_random_test_inputs=False,
148+
)
149+
150+
# Test 1D broadcasting to 2D indexed positions
151+
indices = (torch.tensor([0, 1]), torch.tensor([1, 2]))
152+
values = torch.tensor([10.0, 20.0]) # 1D tensor
153+
self._test_op(
154+
IndexPutModel(),
155+
(torch.randn(3, 4), indices, values),
156+
flow,
157+
generate_random_test_inputs=False,
158+
)
159+
160+
# Test broadcasting with compatible shapes - 1D to multiple 2D slices
161+
indices = (torch.tensor([0, 2]),)
162+
values = torch.tensor([5.0, 15.0]) # Will broadcast to (2, 3) shape
163+
self._test_op(
164+
IndexPutModel(),
165+
(torch.randn(4, 2), indices, values),
166+
flow,
167+
generate_random_test_inputs=False,
168+
)
169+
170+
# Test 2D values broadcasting to 3D indexed positions
171+
indices = (torch.tensor([0, 1]),)
172+
values = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) # 2D tensor
173+
self._test_op(
174+
IndexPutModel(),
175+
(torch.randn(3, 2, 2), indices, values),
176+
flow,
177+
generate_random_test_inputs=False,
178+
)
179+
180+
# Test broadcasting with accumulate=True
181+
indices = (torch.tensor([1, 1, 1]),)
182+
values = torch.tensor([5.0]) # Scalar will be added 3 times to same position
183+
self._test_op(
184+
IndexPutModel(accumulate=True),
185+
(torch.ones(4, 2), indices, values),
186+
flow,
187+
generate_random_test_inputs=False,
188+
)
189+
190+
def test_index_put_two_indices(self, flow: TestFlow) -> None:
191+
# Test basic two-index tensor indexing
192+
indices = (torch.tensor([0, 1, 2]), torch.tensor([1, 0, 2]))
193+
values = torch.tensor([10.0, 20.0, 30.0])
194+
self._test_op(
195+
IndexPutModel(),
196+
(torch.randn(4, 3), indices, values),
197+
flow,
198+
generate_random_test_inputs=False,
199+
)
200+
201+
# Test two-index with different lengths (broadcasting)
202+
indices = (torch.tensor([0, 2]), torch.tensor([1, 1]))
203+
values = torch.tensor([15.0, 25.0])
204+
self._test_op(
205+
IndexPutModel(),
206+
(torch.randn(3, 3), indices, values),
207+
flow,
208+
generate_random_test_inputs=False,
209+
)
210+
211+
# Test two-index with repeated positions and accumulate=True
212+
indices = (torch.tensor([1, 1, 2]), torch.tensor([0, 0, 1]))
213+
values = torch.tensor([5.0, 10.0, 15.0])
214+
self._test_op(
215+
IndexPutModel(accumulate=True),
216+
(torch.zeros(3, 2), indices, values),
217+
flow,
218+
generate_random_test_inputs=False,
219+
)
220+
221+
# Test two-index with repeated positions and accumulate=False
222+
indices = (torch.tensor([1, 1, 2]), torch.tensor([0, 0, 1]))
223+
values = torch.tensor([5.0, 10.0, 15.0])
224+
self._test_op(
225+
IndexPutModel(accumulate=False),
226+
(torch.zeros(3, 2), indices, values),
227+
flow,
228+
generate_random_test_inputs=False,
229+
)
230+
231+
# Test two-index with index broadcast.
232+
indices = (torch.tensor([1]), torch.tensor([0, 0, 1]))
233+
values = torch.tensor([5.0, 10.0, 15.0])
234+
self._test_op(
235+
IndexPutModel(accumulate=False),
236+
(torch.zeros(3, 2), indices, values),
237+
flow,
238+
generate_random_test_inputs=False,
239+
)
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
10+
import torch
11+
from executorch.backends.test.suite.flow import TestFlow
12+
13+
from executorch.backends.test.suite.operators import (
14+
dtype_test,
15+
operator_test,
16+
OperatorTest,
17+
)
18+
19+
20+
class IndexSelectModel(torch.nn.Module):
21+
def __init__(self, dim=0):
22+
super().__init__()
23+
self.dim = dim
24+
25+
def forward(self, x, indices):
26+
return torch.index_select(x, self.dim, indices)
27+
28+
29+
@operator_test
30+
class IndexSelect(OperatorTest):
31+
@dtype_test
32+
def test_index_select_dtype(self, flow: TestFlow, dtype) -> None:
33+
indices = torch.tensor([0, 2], dtype=torch.int64)
34+
self._test_op(
35+
IndexSelectModel(dim=0),
36+
((torch.rand(5, 3) * 100).to(dtype), indices),
37+
flow,
38+
generate_random_test_inputs=False,
39+
)
40+
41+
def test_index_select_dimensions(self, flow: TestFlow) -> None:
42+
indices = torch.tensor([0, 2], dtype=torch.int64)
43+
self._test_op(
44+
IndexSelectModel(dim=0),
45+
(torch.randn(5, 3), indices),
46+
flow,
47+
generate_random_test_inputs=False,
48+
)
49+
50+
indices = torch.tensor([0, 1], dtype=torch.int64)
51+
self._test_op(
52+
IndexSelectModel(dim=1),
53+
(torch.randn(5, 3), indices),
54+
flow,
55+
generate_random_test_inputs=False,
56+
)
57+
58+
indices = torch.tensor([0, 2], dtype=torch.int64)
59+
self._test_op(
60+
IndexSelectModel(dim=2),
61+
(torch.randn(3, 4, 5), indices),
62+
flow,
63+
generate_random_test_inputs=False,
64+
)
65+
66+
def test_index_select_shapes(self, flow: TestFlow) -> None:
67+
indices = torch.tensor([0, 1], dtype=torch.int64)
68+
69+
self._test_op(
70+
IndexSelectModel(dim=0),
71+
(torch.randn(5), indices),
72+
flow,
73+
generate_random_test_inputs=False,
74+
)
75+
76+
self._test_op(
77+
IndexSelectModel(dim=0),
78+
(torch.randn(5, 3), indices),
79+
flow,
80+
generate_random_test_inputs=False,
81+
)
82+
83+
self._test_op(
84+
IndexSelectModel(dim=0),
85+
(torch.randn(5, 3, 2), indices),
86+
flow,
87+
generate_random_test_inputs=False,
88+
)
89+
90+
self._test_op(
91+
IndexSelectModel(dim=0),
92+
(torch.randn(5, 3, 2, 4), indices),
93+
flow,
94+
generate_random_test_inputs=False,
95+
)
96+
97+
def test_index_select_indices(self, flow: TestFlow) -> None:
98+
indices = torch.tensor([2], dtype=torch.int64)
99+
self._test_op(
100+
IndexSelectModel(dim=0),
101+
(torch.randn(5, 3), indices),
102+
flow,
103+
generate_random_test_inputs=False,
104+
)
105+
106+
indices = torch.tensor([0, 2, 4], dtype=torch.int64)
107+
self._test_op(
108+
IndexSelectModel(dim=0),
109+
(torch.randn(5, 3), indices),
110+
flow,
111+
generate_random_test_inputs=False,
112+
)
113+
114+
indices = torch.tensor([1, 1, 3, 3], dtype=torch.int64)
115+
self._test_op(
116+
IndexSelectModel(dim=0),
117+
(torch.randn(5, 3), indices),
118+
flow,
119+
generate_random_test_inputs=False,
120+
)
121+
122+
indices = torch.tensor([4, 3, 2, 1, 0], dtype=torch.int64)
123+
self._test_op(
124+
IndexSelectModel(dim=0),
125+
(torch.randn(5, 3), indices),
126+
flow,
127+
generate_random_test_inputs=False,
128+
)

0 commit comments

Comments
 (0)