Skip to content

Commit 8d0dbc2

Browse files
Add fast path to index_out
Differential Revision: D78455813 Pull Request resolved: #12570
1 parent 7a37676 commit 8d0dbc2

File tree

2 files changed

+198
-2
lines changed

2 files changed

+198
-2
lines changed

kernels/portable/cpu/op_index.cpp

Lines changed: 140 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,159 @@ namespace native {
2222
using Tensor = executorch::aten::Tensor;
2323
using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;
2424

25-
Tensor& index_Tensor_out(
25+
namespace {
26+
27+
bool check_fast_path_conditions(
28+
ET_UNUSED const Tensor& in,
29+
TensorOptList indices,
30+
size_t* dim) {
31+
bool found_index = false;
32+
for (const auto i : c10::irange(indices.size())) {
33+
if (indices[i].has_value()) {
34+
*dim = i;
35+
// Fast path only supports a single non-null index tensor
36+
if (found_index) {
37+
return false;
38+
}
39+
found_index = true;
40+
const Tensor& index = indices[i].value();
41+
ScalarType ix_type = index.scalar_type();
42+
// Fast path only supports only supports Long or Int index tensors
43+
if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
44+
return false;
45+
}
46+
// Fast path only supports a 1-dimensional index tensor
47+
if (index.dim() != 1) {
48+
return false;
49+
}
50+
}
51+
}
52+
53+
// Fast path only supports needs at least one non-null index tensor
54+
if (!found_index) {
55+
return false;
56+
}
57+
58+
return true;
59+
}
60+
61+
bool check_fast_path_args(
62+
const Tensor& in,
63+
TensorOptList indices,
64+
size_t dim,
65+
Tensor& out) {
66+
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
67+
68+
ET_CHECK_OR_RETURN_FALSE(
69+
static_cast<ssize_t>(indices.size()) <= in.dim(),
70+
"Indexing too many dimensions");
71+
72+
const Tensor& index = indices[dim].value();
73+
74+
bool is_valid_index = true;
75+
ET_SWITCH_TWO_TYPES(
76+
Long, Int, index.scalar_type(), ctx, "index_put_", CTYPE, [&]() {
77+
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
78+
for (const auto i : c10::irange(index.numel())) {
79+
if (index_arr[i] < 0 ||
80+
index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
81+
ET_LOG(
82+
Error,
83+
"Index %" PRId64
84+
" out of range for tensor with size %zd"
85+
" at dimension %zu",
86+
static_cast<int64_t>(index_arr[i]),
87+
in.size(dim),
88+
dim);
89+
is_valid_index = false;
90+
break;
91+
}
92+
}
93+
});
94+
95+
ET_CHECK_OR_RETURN_FALSE(
96+
is_valid_index,
97+
"Some index values are not within bounds of input tensor at indexed dim");
98+
99+
return true;
100+
}
101+
102+
Tensor& fast_path(
26103
KernelRuntimeContext& ctx,
27104
const Tensor& in,
28105
TensorOptList indices,
106+
size_t dim,
29107
Tensor& out) {
30108
(void)ctx;
31109

32110
ET_KERNEL_CHECK(
33-
ctx, check_index_args(in, indices, out), InvalidArgument, out);
111+
ctx, check_fast_path_args(in, indices, dim, out), InvalidArgument, out);
112+
113+
const Tensor& index = indices[dim].value();
114+
ScalarType index_type = index.scalar_type();
115+
116+
if (out.dim() == 0) {
117+
memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
118+
return out;
119+
}
120+
121+
size_t leading_dims = getLeadingDims(in, dim);
122+
size_t trailing_dims = getTrailingDims(in, dim);
123+
124+
if (leading_dims == 0 || trailing_dims == 0) {
125+
return out;
126+
}
127+
128+
size_t in_dim_length = in.size(dim);
129+
size_t out_dim_length = out.size(dim);
130+
131+
size_t length_per_step = trailing_dims * in.element_size();
132+
133+
const char* in_data = in.const_data_ptr<char>();
134+
char* out_data = out.mutable_data_ptr<char>();
135+
136+
// @lint-ignore CLANGTIDY facebook-hte-CArray
137+
static constexpr const char op_name[] = "index.Tensor_out";
138+
139+
ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
140+
const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
141+
for (const auto i : c10::irange(leading_dims)) {
142+
const char* src = in_data + i * in_dim_length * length_per_step;
143+
char* dest = out_data + i * out_dim_length * length_per_step;
144+
for (const auto j : c10::irange(out_dim_length)) {
145+
const char* copy_src = src + index_arr[j] * length_per_step;
146+
char* copy_dest = dest + j * length_per_step;
147+
memcpy(copy_dest, copy_src, length_per_step);
148+
}
149+
}
150+
});
151+
152+
return out;
153+
}
154+
155+
} // namespace
156+
157+
Tensor& index_Tensor_out(
158+
KernelRuntimeContext& ctx,
159+
const Tensor& in,
160+
TensorOptList indices,
161+
Tensor& out) {
162+
(void)ctx;
34163

35164
ET_KERNEL_CHECK(
36165
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
37166

38167
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
39168

169+
size_t dim = 0;
170+
bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
171+
if (is_fast_path) {
172+
return fast_path(ctx, in, indices, dim, out);
173+
}
174+
175+
ET_KERNEL_CHECK(
176+
ctx, check_index_args(in, indices, out), InvalidArgument, out);
177+
40178
ScalarType in_type = in.scalar_type();
41179
size_t block_count = count_index_blocks(indices);
42180

kernels/test/op_index_test.cpp

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,3 +627,61 @@ TEST_F(OpIndexTensorOutTest, UpperBoundOutTensor) {
627627
EXPECT_TENSOR_EQ(out, ret);
628628
EXPECT_TENSOR_EQ(ret, expected);
629629
}
630+
631+
TEST_F(OpIndexTensorOutTest, FastPath) {
632+
TensorFactory<ScalarType::Float> tf;
633+
TensorFactory<ScalarType::Long> tfl;
634+
635+
// clang-format off
636+
Tensor x = tf.make(
637+
{2, 3, 4},
638+
{
639+
// [0, :, :]
640+
1., 2., 3., 4., // [0, 0, :]
641+
5., 6., 7., 8., // [0, 1, :]
642+
9., 10., 11., 12., // [0, 2, :]
643+
644+
// [1, :, :]
645+
-1., -2., -3., -4., // [1, 0, :]
646+
-5., -6., -7., -8., // [1, 1, :]
647+
-9., -10., -11., -12., // [1, 2, :]
648+
});
649+
// clang-format on
650+
651+
optional<Tensor> indices[] = {
652+
optional<Tensor>(),
653+
optional<Tensor>(),
654+
optional<Tensor>(tfl.make({3}, {2, 0, 1}))};
655+
656+
Tensor out = tf.zeros({2, 3, 3});
657+
// clang-format off
658+
Tensor expected = tf.make(
659+
{2, 3, 3},
660+
{
661+
3., 1., 2.,
662+
7., 5., 6.,
663+
11., 9., 10.,
664+
665+
-3., -1., -2.,
666+
-7., -5., -6.,
667+
-11., -9., -10.,
668+
});
669+
// clang-format on
670+
671+
op_index_tensor_out(x, indices, out);
672+
673+
EXPECT_TENSOR_EQ(out, expected);
674+
}
675+
676+
TEST_F(OpIndexTensorOutTest, FastPathZeroDim) {
677+
TensorFactory<ScalarType::Float> tf;
678+
TensorFactory<ScalarType::Long> tfl;
679+
680+
Tensor x = tf.ones({0});
681+
optional<Tensor> indices[] = {optional<Tensor>(tfl.zeros({0}))};
682+
Tensor out = tf.zeros({0});
683+
Tensor expected = tf.ones({0});
684+
op_index_tensor_out(x, indices, out);
685+
686+
EXPECT_TENSOR_EQ(out, expected);
687+
}

0 commit comments

Comments
 (0)