
Commit 874d3c1

Optimize index_out via fast path

Differential Revision: D81142086
Pull Request resolved: #13731
Parent: 932818c
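Note on what the fast path covers: it handles calls where exactly one non-null index tensor is present, and that tensor is 1-D with dtype Long or Int, i.e. a gather along a single dimension (the same access pattern as index_select). The sketch below is a hedged illustration of which call shapes qualify, written against the LibTorch C++ indexing API for the equivalent aten semantics; it is not ExecuTorch runtime code, and the variable names are illustrative only.

#include <torch/torch.h>

#include <iostream>

using torch::indexing::Slice;

int main() {
  auto x = torch::arange(24, torch::kFloat).reshape({2, 4, 3});
  auto idx = torch::tensor({3, 0}, torch::kLong); // single 1-D Long index

  // One non-null 1-D Long/Int index tensor: the pattern the fast path targets.
  auto a = x.index({Slice(), idx}); // x[:, idx] -> shape (2, 2, 3)

  // Two non-null index tensors: stays on the generic path.
  auto idx0 = torch::tensor({1, 0}, torch::kLong);
  auto b = x.index({idx0, idx}); // x[idx0, idx] -> shape (2, 3)

  std::cout << a.sizes() << " " << b.sizes() << "\n";
  return 0;
}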

2 files changed: +451, -22 lines


kernels/portable/cpu/op_index.cpp

Lines changed: 171 additions & 3 deletions
@@ -22,21 +22,189 @@ namespace native {
 using Tensor = executorch::aten::Tensor;
 using TensorOptList = executorch::aten::ArrayRef<std::optional<Tensor>>;
 
-Tensor& index_Tensor_out(
+namespace {
+
+bool check_fast_path_conditions(
+    ET_UNUSED const Tensor& in,
+    TensorOptList indices,
+    size_t* dim) {
+  bool found_index = false;
+  for (const auto i : c10::irange(indices.size())) {
+    if (indices[i].has_value()) {
+      *dim = i;
+      // Fast path only supports a single non-null index tensor
+      if (found_index) {
+        return false;
+      }
+      found_index = true;
+      const Tensor& index = indices[i].value();
+      ScalarType ix_type = index.scalar_type();
+      // Fast path only supports Long or Int index tensors
+      if (ix_type != ScalarType::Long && ix_type != ScalarType::Int) {
+        return false;
+      }
+      // Fast path only supports a 1-dimensional index tensor
+      if (index.dim() != 1) {
+        return false;
+      }
+    }
+  }
+
+  // Fast path needs at least one non-null index tensor
+  if (!found_index) {
+    return false;
+  }
+
+  return true;
+}
+
+bool check_fast_path_args(
     KernelRuntimeContext& ctx,
     const Tensor& in,
     TensorOptList indices,
+    size_t dim,
     Tensor& out) {
-  (void)ctx;
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
+
+  ET_CHECK_OR_RETURN_FALSE(
+      static_cast<ssize_t>(indices.size()) <= in.dim(),
+      "Indexing too many dimensions");
+
+  const Tensor& index = indices[dim].value();
+
+  bool is_valid_index = true;
+  ET_SWITCH_TWO_TYPES(
+      Long, Int, index.scalar_type(), ctx, "index.Tensor", CTYPE, [&]() {
+        const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+        for (const auto i : c10::irange(index.numel())) {
+          if (index_arr[i] < 0 ||
+              index_arr[i] >= static_cast<CTYPE>(in.size(dim))) {
+            ET_LOG(
+                Error,
+                "Index %" PRId64
+                " out of range for tensor with size %zd"
+                " at dimension %zu",
+                static_cast<int64_t>(index_arr[i]),
+                in.size(dim),
+                dim);
+            is_valid_index = false;
+            break;
+          }
+        }
+      });
+
+  ET_CHECK_OR_RETURN_FALSE(
+      is_valid_index,
+      "Some index values are not within bounds of input tensor at indexed dim");
 
+  return true;
+}
+
+void get_fast_path_index_out_target_size(
+    const Tensor& in,
+    TensorOptList indices,
+    size_t dim,
+    Tensor::SizesType* out_sizes,
+    size_t* out_ndim) {
+  *out_ndim = in.dim();
+
+  for (const auto d : c10::irange(static_cast<size_t>(in.dim()))) {
+    if (d != dim) {
+      out_sizes[d] = static_cast<Tensor::SizesType>(in.size(d));
+    } else {
+      out_sizes[d] =
+          static_cast<Tensor::SizesType>(indices[dim].value().numel());
+    }
+  }
+}
+
+Tensor& fast_path(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    TensorOptList indices,
+    size_t dim,
+    Tensor& out) {
   ET_KERNEL_CHECK(
-      ctx, check_index_args(in, indices, out), InvalidArgument, out);
+      ctx,
+      check_fast_path_args(ctx, in, indices, dim, out),
+      InvalidArgument,
+      out);
+
+  const Tensor& index = indices[dim].value();
+  ScalarType index_type = index.scalar_type();
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  Tensor::SizesType expected_size[kTensorDimensionLimit];
+  size_t expected_ndim = 0;
+  get_fast_path_index_out_target_size(
+      in, indices, dim, expected_size, &expected_ndim);
 
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {expected_size, expected_ndim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  if (out.dim() == 0) {
+    memcpy(out.mutable_data_ptr(), in.const_data_ptr(), out.nbytes());
+    return out;
+  }
+
+  size_t leading_dims = getLeadingDims(in, dim);
+  size_t trailing_dims = getTrailingDims(in, dim);
+
+  if (leading_dims == 0 || trailing_dims == 0) {
+    return out;
+  }
+
+  size_t in_dim_length = in.size(dim);
+  size_t out_dim_length = out.size(dim);
+
+  size_t length_per_step = trailing_dims * in.element_size();
+
+  const char* in_data = in.const_data_ptr<char>();
+  char* out_data = out.mutable_data_ptr<char>();
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "index.Tensor_out";
+
+  ET_SWITCH_TWO_TYPES(Long, Int, index_type, ctx, op_name, CTYPE, [&]() {
+    const CTYPE* const index_arr = index.const_data_ptr<CTYPE>();
+    for (const auto i : c10::irange(leading_dims)) {
+      const char* src = in_data + i * in_dim_length * length_per_step;
+      char* dest = out_data + i * out_dim_length * length_per_step;
+      for (const auto j : c10::irange(out_dim_length)) {
+        const char* copy_src = src + index_arr[j] * length_per_step;
+        char* copy_dest = dest + j * length_per_step;
+        memcpy(copy_dest, copy_src, length_per_step);
+      }
+    }
+  });
+
+  return out;
+}
+
+} // namespace
+
+Tensor& index_Tensor_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    TensorOptList indices,
+    Tensor& out) {
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
 
   ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
 
+  size_t dim = 0;
+  bool is_fast_path = check_fast_path_conditions(in, indices, &dim);
+  if (is_fast_path) {
+    return fast_path(ctx, in, indices, dim, out);
+  }
+
+  ET_KERNEL_CHECK(
+      ctx, check_index_args(in, indices, out), InvalidArgument, out);
+
   ScalarType in_type = in.scalar_type();
   size_t block_count = count_index_blocks(indices);

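To make the fast path's copy geometry concrete: the kernel views the input as leading_dims blocks of in_dim_length contiguous rows, where each row is length_per_step = trailing_dims * element_size bytes, so every gathered row moves with a single memcpy instead of a per-element strided copy; the output shape is the input shape with size(dim) replaced by index.numel(). Below is a minimal standalone sketch of that byte-level layout under the same assumptions (plain C++ with local names only, no ExecuTorch types), for an input of shape (2, 4, 3) indexed along dim 1 with {3, 0}.

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  // Input viewed as shape (2, 4, 3): leading_dims = 2, in_dim_length = 4,
  // trailing_dims = 3. Fill with 0..23 so rows are easy to recognize.
  float in[24];
  for (int v = 0; v < 24; ++v) {
    in[v] = static_cast<float>(v);
  }

  const int64_t index[] = {3, 0}; // gather rows 3 and 0 along dim 1
  const size_t leading_dims = 2;
  const size_t in_dim_length = 4;
  const size_t out_dim_length = 2; // == number of index entries
  const size_t length_per_step = 3 * sizeof(float); // bytes per contiguous row

  float out[12]; // output shape (2, 2, 3)
  const char* in_data = reinterpret_cast<const char*>(in);
  char* out_data = reinterpret_cast<char*>(out);

  // Same loop structure as fast_path: one memcpy per (block, index) pair,
  // i.e. four 12-byte copies in total for this example.
  for (size_t i = 0; i < leading_dims; ++i) {
    const char* src = in_data + i * in_dim_length * length_per_step;
    char* dest = out_data + i * out_dim_length * length_per_step;
    for (size_t j = 0; j < out_dim_length; ++j) {
      std::memcpy(
          dest + j * length_per_step,
          src + index[j] * length_per_step,
          length_per_step);
    }
  }

  // Prints: 9 10 11 0 1 2 21 22 23 12 13 14 (== in[:, {3, 0}, :])
  for (float v : out) {
    std::printf("%g ", v);
  }
  std::printf("\n");
  return 0;
}

The saving comes from the row-granular copies: the generic path walks every output element through strided coordinate arithmetic, while the fast path issues leading_dims * out_dim_length bulk memcpy calls.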