Commit 9b0e776

manuelcandales authored and facebook-github-bot committed
Implement index_put inplace (#10996)
Summary: Pull Request resolved: #10996

Reviewed By: JacobSzwejbka

Differential Revision: D74765433
1 parent 1b063ca commit 9b0e776
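The new portable kernel covers only a special case of index_put_: exactly one non-null index tensor, which must be 1-D with Long or Int dtype, and accumulate=false; anything else is rejected, and the existing index_put.out path is unchanged apart from the TensorOptList alias. The following is a minimal standalone C++ sketch of that special case's semantics on plain buffers, with made-up names and a 2x3 example shape; it is not the ExecuTorch kernel API.

// Sketch only: writes values[:, j] into in[:, index[j]] (accumulate=false),
// i.e. the single non-null 1-D index case the new kernel supports.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // in: shape [2, 3]; values: shape [2, 2]; index = {0, 2} along dim 1.
  std::vector<float> in = {1, 1, 1,
                           2, 2, 2};
  std::vector<float> values = {10, 30,
                               40, 60};
  std::vector<int64_t> index = {0, 2};

  const size_t rows = 2, in_cols = 3, values_cols = 2;
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < values_cols; ++j) {
      // Overwrite (no accumulation) the indexed column of row i.
      in[i * in_cols + index[j]] = values[i * values_cols + j];
    }
  }

  for (size_t i = 0; i < rows; ++i) {
    for (size_t c = 0; c < in_cols; ++c) {
      std::printf("%g ", in[i * in_cols + c]);
    }
    std::printf("\n");
  }
  // Prints:
  // 10 1 30
  // 40 2 60
  return 0;
}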

4 files changed: +277, -1 lines

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -201,6 +201,8 @@
 
 - op: index_put.out
 
+- op: index_put_
+
 - op: index_select.out
 
 - op: index.Tensor_out

kernels/portable/cpu/op_index_put.cpp

Lines changed: 173 additions & 1 deletion
@@ -11,18 +11,21 @@
 
 #include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
 #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/runtime/core/exec_aten/util/tensor_shape_to_c_string.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
 using Tensor = executorch::aten::Tensor;
+using TensorOptList =
+    executorch::aten::ArrayRef<executorch::aten::optional<Tensor>>;
 
 Tensor& index_put_out(
     KernelRuntimeContext& ctx,
     const Tensor& in,
-    executorch::aten::ArrayRef<executorch::aten::optional<Tensor>> indices,
+    TensorOptList indices,
     const Tensor& values,
     const bool accumulate,
     Tensor& out) {
@@ -154,6 +157,175 @@ Tensor& index_put_out(
   return out;
 }
 
+namespace {
+
+bool check_special_case_in_place_args(
+    Tensor& in,
+    TensorOptList indices,
+    const Tensor& values,
+    const bool accumulate,
+    size_t* dim) {
+  ET_CHECK_OR_RETURN_FALSE(
+      !accumulate,
+      "Special case in-place index_put does not support accumulate");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      static_cast<ssize_t>(indices.size()) <= in.dim(),
+      "Indexing too many dimensions");
+
+  bool found_index = false;
+  for (const auto i : c10::irange(indices.size())) {
+    if (indices[i].has_value()) {
+      *dim = i;
+      ET_CHECK_OR_RETURN_FALSE(
+          !found_index,
+          "Special case in-place index_put only supports a single non-null index tensor");
+      found_index = true;
+      const Tensor& index = indices[i].value();
+      ScalarType ix_type = index.scalar_type();
+      ET_CHECK_OR_RETURN_FALSE(
+          ix_type == ScalarType::Long || ix_type == ScalarType::Int,
+          "Special case in-place index_put only supports Long or Int index tensors; got %d",
+          static_cast<int>(ix_type));
+      ET_CHECK_OR_RETURN_FALSE(
+          index.dim() == 1,
+          "Special case in-place index_put only supports 1-dimensional index tensors; got %d",
+          static_cast<int>(index.dim()));
+    }
+  }
+
+  ET_CHECK_OR_RETURN_FALSE(
+      found_index,
+      "Special case in-place index_put needs at least one non-null index tensor");
+
+  const Tensor& index = indices[*dim].value();
+
+  bool is_valid_index = true;
+  ET_SWITCH_TWO_TYPES(Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
+    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+    for (const auto i : c10::irange(index.numel())) {
+      if (index_arr[i] < 0 || index_arr[i] >= static_cast<CTYPE>(in.size(*dim))) {
+        ET_LOG(
+            Error,
+            "Index %" PRId64 " out of range for tensor with size %zd"
+            " at dimension %zu",
+            static_cast<int64_t>(index_arr[i]), in.size(*dim), *dim);
+        is_valid_index = false;
+        break;
+      }
+    }
+  });
+
+  ET_CHECK_OR_RETURN_FALSE(
+      is_valid_index,
+      "Some index values are not within bounds of input tensor at indexed dim");
+
+  ET_CHECK_OR_RETURN_FALSE(
+      values.size(*dim) == index.size(0),
+      "Special case in-place index_put requires values to match index length at the indexed dim; values.size(%zu) = %" ET_PRI_TENSOR_SIZE
+      ", index_length = %zd",
+      *dim,
+      values.size(*dim),
+      index.size(0));
+
+  Tensor::SizesType expected_values_size[kTensorDimensionLimit] = {};
+  size_t expected_values_dim = 0;
+  for (const auto i : c10::irange(in.dim())) {
+    if (i != *dim) {
+      expected_values_size[i] = static_cast<Tensor::SizesType>(in.size(i));
+    }
+  }
+  expected_values_size[*dim] = static_cast<Tensor::SizesType>(index.size(0));
+  expected_values_dim = in.dim();
+
+#if ET_LOG_ENABLED
+  auto in_shape_str = executorch::runtime::tensor_shape_to_c_string(
+      executorch::runtime::Span<const Tensor::SizesType>(
+          in.sizes().data(), in.sizes().size()));
+  auto values_shape_str = executorch::runtime::tensor_shape_to_c_string(
+      executorch::runtime::Span<const Tensor::SizesType>(
+          values.sizes().data(), values.sizes().size()));
+
+  ET_CHECK_OR_RETURN_FALSE(
+      tensor_has_expected_size(
+          values, {expected_values_size, expected_values_dim}),
+      "Special case in-place index_put requires values to match input shape except for indexed dim; got input shape %s and values shape %s",
+      in_shape_str.data(),
+      values_shape_str.data());
+#else
+  ET_CHECK_OR_RETURN_FALSE(
+      tensor_has_expected_size(
+          values, {expected_values_size, expected_values_dim}),
+      "Special case in-place index_put requires values to match input shape except for indexed dim");
+#endif // ET_LOG_ENABLED
+
+  return true;
+}
+
+} // namespace
+
+Tensor& index_put_(
+    KernelRuntimeContext& ctx,
+    Tensor& in,
+    TensorOptList indices,
+    const Tensor& values,
+    const bool accumulate) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dtype(in, values), InvalidArgument, in);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, values), InvalidArgument, in);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, in);
+
+  size_t dim = 0;
+  ET_KERNEL_CHECK(
+      ctx,
+      check_special_case_in_place_args(in, indices, values, accumulate, &dim),
+      InvalidArgument,
+      in);
+
+  const Tensor& index = indices[dim].value();
+  ScalarType index_type = index.scalar_type();
+
+  if (in.dim() == 0) {
+    memcpy(in.mutable_data_ptr(), values.const_data_ptr(), in.nbytes());
+    return in;
+  }
+
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+
+  if (leading_dims == 0 || trailing_dims == 0) {
+    return in;
+  }
+
+  size_t values_dim_length = values.size(dim);
+  size_t in_dim_length = in.size(dim);
+
+  size_t length_per_step = trailing_dims * in.element_size();
+
+  const char* values_data = values.const_data_ptr<char>();
+  char* in_data = in.mutable_data_ptr<char>();
+
+  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, "index_put_", CTYPE, [&]() {
+    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+    for (const auto i : c10::irange(leading_dims)) {
+      const char* src = values_data + i * values_dim_length * length_per_step;
+      char* dest = in_data + i * in_dim_length * length_per_step;
+      for (const auto j : c10::irange(values_dim_length)) {
+        const char* copy_src = src + j * length_per_step;
+        char* copy_dest = dest + index_arr[j] * length_per_step;
+        memcpy(copy_dest, copy_src, length_per_step);
+      }
+    }
+  });
+
+  return in;
+}
+
 } // namespace native
 } // namespace executor
 } // namespace torch
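The copy loop at the end of index_put_ works on raw bytes: for a contiguous, default-dim-order input indexed at dim d, leading_dims is the product of the sizes before d, trailing_dims the product of the sizes after d, and each memcpy moves one trailing_dims-sized slice of values into the indexed slot. A self-contained sketch of that arithmetic follows; the function name index_put_chunks and its flat-buffer parameters are hypothetical and only illustrate the addressing, they are not the kernel's signature.

// Sketch of the chunk-copy addressing used by the special-case kernel,
// on plain contiguous byte buffers.
#include <cstddef>
#include <cstdint>
#include <cstring>

void index_put_chunks(
    char* in_data,             // contiguous input buffer
    const char* values_data,   // contiguous values buffer
    const int64_t* index_arr,  // 1-D index along dim d
    size_t leading_dims,       // product of sizes before dim d
    size_t in_dim_length,      // in.size(d)
    size_t values_dim_length,  // values.size(d) == index length
    size_t trailing_dims,      // product of sizes after dim d
    size_t elem_size) {        // in.element_size()
  const size_t length_per_step = trailing_dims * elem_size;
  for (size_t i = 0; i < leading_dims; ++i) {
    const char* src = values_data + i * values_dim_length * length_per_step;
    char* dest = in_data + i * in_dim_length * length_per_step;
    for (size_t j = 0; j < values_dim_length; ++j) {
      // Copy one trailing slice of values into the position picked by index_arr[j].
      std::memcpy(
          dest + index_arr[j] * length_per_step,
          src + j * length_per_step,
          length_per_step);
    }
  }
}

On the 2x3x4 test case added below (index {0, 2} along dim 1), leading_dims is 2, trailing_dims is 4, and each memcpy moves one 4-element row of values into row 0 or row 2 of the corresponding slice of the input.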

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
@@ -452,6 +452,11 @@
     - arg_meta: null
       kernel_name: torch::executor::index_put_out
 
+- op: index_put_
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::index_put_
+
 - op: index_select.out
   kernels:
     - arg_meta: null

kernels/test/op_index_put_test.cpp

Lines changed: 97 additions & 0 deletions
@@ -1011,3 +1011,100 @@ TEST_F(OpIndexPutOutTest, DynamicShapeUnbound) {
   test_dynamic_shape(
       {1, 1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
 }
+
+class OpIndexPutInplaceTest : public OperatorTest {
+ protected:
+  Tensor& op_index_put_(
+      Tensor& input,
+      OptTensorArrayRef indices,
+      const Tensor& values,
+      const bool accumulate) {
+#ifdef USE_ATEN_LIB
+    c10::List<std::optional<at::Tensor>> indices_list(indices);
+    return torch::executor::aten::index_put_(
+        context_, input, indices_list, values, accumulate);
+#else
+    return torch::executor::aten::index_put_(
+        context_, input, indices, values, accumulate);
+#endif
+  }
+
+  template <
+      executorch::aten::ScalarType INPUT_DTYPE,
+      executorch::aten::ScalarType INDICES_DTYPE>
+  void test_dtype() {
+    TensorFactory<INPUT_DTYPE> tf;
+    TensorFactory<INDICES_DTYPE> tfl;
+
+    // clang-format off
+    Tensor x = tf.make(
+        {2, 3, 4},
+        {
+          // [0, :, :]
+          1, 1, 1, 1, // [0, 0, :]
+          0, 0, 0, 0, // [0, 1, :]
+          2, 2, 2, 2, // [0, 2, :]
+
+          // [1, :, :]
+          3, 3, 3, 3, // [1, 0, :]
+          0, 0, 0, 0, // [1, 1, :]
+          5, 5, 5, 5, // [1, 2, :]
+        });
+    // clang-format on
+
+    optional<Tensor> indices[] = {
+        optional<Tensor>(),
+        optional<Tensor>(tfl.make({2}, {0, 2})),
+    };
+
+    // clang-format off
+    Tensor values = tf.make(
+        {2, 2, 4},
+        {
+          // [0, :, :]
+          1, 2, 3, 4, // [0, 0, :]
+          5, 6, 7, 8, // [0, 1, :]
+
+          // [1, :, :]
+          9, 10, 11, 12, // [1, 0, :]
+          13, 14, 15, 16, // [1, 1, :]
+        });
+    // clang-format on
+
+    // clang-format off
+    Tensor expected = tf.make(
+        {2, 3, 4},
+        {
+          // [0, :, :]
+          1, 2, 3, 4, // [0, 0, :]
+          0, 0, 0, 0, // [0, 1, :]
+          5, 6, 7, 8, // [0, 2, :]
+
+          // [1, :, :]
+          9, 10, 11, 12, // [1, 0, :]
+          0, 0, 0, 0, // [1, 1, :]
+          13, 14, 15, 16, // [1, 2, :]
+        });
+    // clang-format on
+
+    Tensor ret =
+        op_index_put_(x, indices, values, /*accumulate=*/false);
+
+    EXPECT_TENSOR_EQ(ret, x);
+    EXPECT_TENSOR_EQ(ret, expected);
+  }
+};
+
+TEST_F(OpIndexPutInplaceTest, AllDtypesSupportedForInput) {
+#define TEST_ENTRY(ctype, dtype) \
+  test_dtype<ScalarType::dtype, ScalarType::Long>();
+
+  ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
+
+#undef TEST_ENTRY
+}
+
+TEST_F(OpIndexPutInplaceTest, AllDtypesSupportedForIndicesList) {
+  test_dtype<ScalarType::Float, ScalarType::Long>();
+  test_dtype<ScalarType::Float, ScalarType::Int>();
+}
