Skip to content

Commit 62873b3

Browse files
Implement index_put inplace
Differential Revision: D74765433 Pull Request resolved: #10996
1 parent 71275e5 commit 62873b3

File tree

4 files changed

+278
-1
lines changed

4 files changed

+278
-1
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@
201201

202202
- op: index_put.out
203203

204+
- op: index_put_
205+
204206
- op: index_select.out
205207

206208
- op: index.Tensor_out

kernels/portable/cpu/op_index_put.cpp

Lines changed: 175 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,21 @@
1111

1212
#include <executorch/kernels/portable/cpu/util/advanced_index_util.h>
1313
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
14+
#include <executorch/runtime/core/exec_aten/util/tensor_shape_to_c_string.h>
1415
#include <executorch/runtime/kernel/kernel_includes.h>
1516

1617
namespace torch {
1718
namespace executor {
1819
namespace native {
1920

2021
using Tensor = executorch::aten::Tensor;
22+
using TensorOptList =
23+
executorch::aten::ArrayRef<executorch::aten::optional<Tensor>>;
2124

2225
Tensor& index_put_out(
2326
KernelRuntimeContext& ctx,
2427
const Tensor& in,
25-
executorch::aten::ArrayRef<executorch::aten::optional<Tensor>> indices,
28+
TensorOptList indices,
2629
const Tensor& values,
2730
const bool accumulate,
2831
Tensor& out) {
@@ -154,6 +157,177 @@ Tensor& index_put_out(
154157
return out;
155158
}
156159

160+
namespace {
161+
162+
bool check_special_case_in_place_args(
163+
Tensor& in,
164+
TensorOptList indices,
165+
const Tensor& values,
166+
const bool accumulate,
167+
size_t* dim) {
168+
ET_CHECK_OR_RETURN_FALSE(
169+
!accumulate,
170+
"Special case in-place index_put does not support accumulate");
171+
172+
ET_CHECK_OR_RETURN_FALSE(
173+
static_cast<ssize_t>(indices.size()) <= in.dim(),
174+
"Indexing too many dimensions");
175+
176+
bool found_index = false;
177+
for (const auto i : c10::irange(indices.size())) {
178+
if (indices[i].has_value()) {
179+
*dim = i;
180+
ET_CHECK_OR_RETURN_FALSE(
181+
!found_index,
182+
"Special case in-place index_put only supports a single non-null index tensor");
183+
found_index = true;
184+
const Tensor& index = indices[i].value();
185+
ScalarType ix_type = index.scalar_type();
186+
ET_CHECK_OR_RETURN_FALSE(
187+
ix_type == ScalarType::Long || ix_type == ScalarType::Int,
188+
"Special case in-place index_put only supports Long or Int index tensors; got %d",
189+
static_cast<int>(ix_type));
190+
ET_CHECK_OR_RETURN_FALSE(
191+
index.dim() == 1,
192+
"Special case in-place index_put only supports 1-dimensional index tensors; got %d",
193+
static_cast<int>(ix_type));
194+
}
195+
}
196+
197+
ET_CHECK_OR_RETURN_FALSE(
198+
found_index,
199+
"Special case in-place index_put needs at least one non-null index tensor");
200+
201+
const Tensor& index = indices[*dim].value();
202+
203+
bool is_valid_index = true;
204+
ET_SWITCH_TWO_TYPES(
205+
Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
206+
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
207+
for (const auto i : c10::irange(index.numel())) {
208+
if (index_arr[i] < 0 ||
209+
index_arr[i] >= static_cast<CTYPE>(in.size(*dim))) {
210+
ET_LOG(
211+
Error,
212+
"Index %" PRId64
213+
" out of range for tensor with size %zd"
214+
" at dimension %zu",
215+
static_cast<int64_t>(index_arr[i]),
216+
in.size(*dim),
217+
*dim);
218+
is_valid_index = false;
219+
break;
220+
}
221+
}
222+
});
223+
224+
ET_CHECK_OR_RETURN_FALSE(
225+
is_valid_index,
226+
"Some index values are not within bounds of input tensor at indexed dim");
227+
228+
ET_CHECK_OR_RETURN_FALSE(
229+
values.size(*dim) == index.size(0),
230+
"Special case in-place index_put requires values to match index length at the indexed dim; values.size(%zu) = %" ET_PRI_TENSOR_SIZE
231+
", index_length = %zd",
232+
*dim,
233+
values.size(*dim),
234+
index.size(0));
235+
236+
Tensor::SizesType expected_values_size[kTensorDimensionLimit] = {};
237+
size_t in_ndim = static_cast<size_t>(in.dim());
238+
for (const auto i : c10::irange(in_ndim)) {
239+
if (i != *dim) {
240+
expected_values_size[i] = static_cast<Tensor::SizesType>(in.size(i));
241+
}
242+
}
243+
expected_values_size[*dim] = static_cast<Tensor::SizesType>(index.size(0));
244+
245+
#if ET_LOG_ENABLED
246+
auto in_shape_str = executorch::runtime::tensor_shape_to_c_string(
247+
executorch::runtime::Span<const Tensor::SizesType>(
248+
in.sizes().data(), in.sizes().size()));
249+
auto values_shape_str = executorch::runtime::tensor_shape_to_c_string(
250+
executorch::runtime::Span<const Tensor::SizesType>(
251+
values.sizes().data(), values.sizes().size()));
252+
253+
ET_CHECK_OR_RETURN_FALSE(
254+
tensor_has_expected_size(values, {expected_values_size, in_ndim}),
255+
"Special case in-place index_put requires values to match input shape except for indexed dim; got input shape %s and values shape %s",
256+
in_shape_str.data(),
257+
values_shape_str.data());
258+
#else
259+
ET_CHECK_OR_RETURN_FALSE(
260+
tensor_has_expected_size(values, {expected_values_size, in_ndim}),
261+
"Special case in-place index_put requires values to match input shape except for indexed dim");
262+
#endif // ET_LOG_ENABLED
263+
264+
return true;
265+
}
266+
267+
} // namespace
268+
269+
// In-place index_put for the special case validated by
// check_special_case_in_place_args: a single non-null 1-D integer index
// tensor, accumulate == false. Copies contiguous slices of `values` into
// `in` at the positions selected by the index along the indexed dim, and
// returns `in`.
Tensor& index_put_(
    KernelRuntimeContext& ctx,
    Tensor& in,
    TensorOptList indices,
    const Tensor& values,
    const bool accumulate) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dtype(in, values), InvalidArgument, in);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, values), InvalidArgument, in);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, in);

  // `dim` receives the position of the single non-null index tensor.
  size_t dim = 0;
  ET_KERNEL_CHECK(
      ctx,
      check_special_case_in_place_args(in, indices, values, accumulate, &dim),
      InvalidArgument,
      in);

  const Tensor& index = indices[dim].value();
  ScalarType index_type = index.scalar_type();

  // NOTE(review): for a 0-dim input the args check above requires one
  // non-null index yet allows at most in.dim() == 0 indices, so this branch
  // looks unreachable — confirm before relying on it.
  if (in.dim() == 0) {
    memcpy(in.mutable_data_ptr(), values.const_data_ptr(), in.nbytes());
    return in;
  }

  // View the tensors as (leading_dims, size(dim), trailing_dims): the
  // leading dims enumerate independent "planes", the trailing dims form one
  // contiguous slice that is copied per index entry.
  size_t leading_dims = getLeadingDims(in, dim);
  size_t trailing_dims = getTrailingDims(in, dim);

  // Any zero-sized dimension means there is nothing to write.
  if (leading_dims == 0 || trailing_dims == 0) {
    return in;
  }

  size_t values_dim_length = values.size(dim);
  size_t in_dim_length = in.size(dim);

  // Byte count of one contiguous slice (all dims after `dim`).
  size_t length_per_step = trailing_dims * in.element_size();

  const char* values_data = values.const_data_ptr<char>();
  char* in_data = in.mutable_data_ptr<char>();

  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, "index_put_", CTYPE, [&]() {
    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
    for (const auto i : c10::irange(leading_dims)) {
      // i-th plane of each tensor; the planes differ in extent along `dim`
      // (values has values_dim_length rows, in has in_dim_length rows).
      const char* src = values_data + i * values_dim_length * length_per_step;
      char* dest = in_data + i * in_dim_length * length_per_step;
      for (const auto j : c10::irange(values_dim_length)) {
        // Row j of values goes to row index_arr[j] of the input plane.
        const char* copy_src = src + j * length_per_step;
        char* copy_dest = dest + index_arr[j] * length_per_step;
        memcpy(copy_dest, copy_src, length_per_step);
      }
    }
  });

  return in;
}
330+
157331
} // namespace native
158332
} // namespace executor
159333
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,11 @@
452452
- arg_meta: null
453453
kernel_name: torch::executor::index_put_out
454454

455+
- op: index_put_
456+
kernels:
457+
- arg_meta: null
458+
kernel_name: torch::executor::index_put_
459+
455460
- op: index_select.out
456461
kernels:
457462
- arg_meta: null

kernels/test/op_index_put_test.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,3 +1011,99 @@ TEST_F(OpIndexPutOutTest, DynamicShapeUnbound) {
10111011
test_dynamic_shape(
10121012
{1, 1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
10131013
}
1014+
1015+
// Test fixture for the in-place index_put_ kernel. Exercises the special
// case of a single non-null 1-D index tensor with accumulate == false.
class OpIndexPutInplaceTest : public OperatorTest {
 protected:
  // Dispatches to the ATen or portable kernel depending on the build.
  Tensor& op_index_put_(
      Tensor& input,
      OptTensorArrayRef indices,
      const Tensor& values,
      const bool accumulate) {
#ifdef USE_ATEN_LIB
    // ATen expects the indices as a c10::List of optional tensors.
    c10::List<std::optional<at::Tensor>> indices_list(indices);
    return torch::executor::aten::index_put_(
        context_, input, indices_list, values, accumulate);
#else
    return torch::executor::aten::index_put_(
        context_, input, indices, values, accumulate);
#endif
  }

  // Writes `values` into rows 0 and 2 of dim 1 of a {2, 3, 4} tensor and
  // checks the result both aliases and equals the expectation.
  template <
      executorch::aten::ScalarType INPUT_DTYPE,
      executorch::aten::ScalarType INDICES_DTYPE>
  void test_dtype() {
    TensorFactory<INPUT_DTYPE> tf;
    TensorFactory<INDICES_DTYPE> tfl;

    // clang-format off
    Tensor x = tf.make(
        {2, 3, 4},
        {
            // [0, :, :]
            1, 1, 1, 1, // [0, 0, :]
            0, 0, 0, 0, // [0, 1, :]
            2, 2, 2, 2, // [0, 2, :]

            // [1, :, :]
            3, 3, 3, 3, // [1, 0, :]
            0, 0, 0, 0, // [1, 1, :]
            5, 5, 5, 5, // [1, 2, :]
        });
    // clang-format on

    // Null entry for dim 0, so dim 1 is the single indexed dimension.
    optional<Tensor> indices[] = {
        optional<Tensor>(),
        optional<Tensor>(tfl.make({2}, {0, 2})),
    };

    // clang-format off
    Tensor values = tf.make(
        {2, 2, 4},
        {
            // [0, :, :]
            1, 2, 3, 4, // [0, 0, :]
            5, 6, 7, 8, // [0, 1, :]

            // [1, :, :]
            9, 10, 11, 12, // [1, 0, :]
            13, 14, 15, 16, // [1, 1, :]
        });
    // clang-format on

    // clang-format off
    Tensor expected = tf.make(
        {2, 3, 4},
        {
            // [0, :, :]
            1, 2, 3, 4, // [0, 0, :]
            0, 0, 0, 0, // [0, 1, :]
            5, 6, 7, 8, // [0, 2, :]

            // [1, :, :]
            9, 10, 11, 12, // [1, 0, :]
            0, 0, 0, 0, // [1, 1, :]
            13, 14, 15, 16, // [1, 2, :]
        });
    // clang-format on

    Tensor ret = op_index_put_(x, indices, values, /*accumulate=*/false);

    // The op mutates and returns its input.
    EXPECT_TENSOR_EQ(ret, x);
    EXPECT_TENSOR_EQ(ret, expected);
  }
};
1096+
1097+
// Runs the fixture's test_dtype for every real/half/bool/bf16 input dtype
// with Long indices.
TEST_F(OpIndexPutInplaceTest, AllDtypesSupportedForInput) {
#define TEST_ENTRY(ctype, dtype) \
  test_dtype<ScalarType::dtype, ScalarType::Long>();

  ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);

#undef TEST_ENTRY
}

// Both supported index dtypes (Long and Int) against a Float input.
TEST_F(OpIndexPutInplaceTest, AllDtypesSupportedForIndicesList) {
  test_dtype<ScalarType::Float, ScalarType::Long>();
  test_dtype<ScalarType::Float, ScalarType::Int>();
}

0 commit comments

Comments
 (0)