Skip to content

Commit 324864d

Browse files
Add op: topk
Differential Revision: D59936967 Pull Request resolved: #4307
1 parent a4092c5 commit 324864d

File tree

5 files changed

+398
-0
lines changed

5 files changed

+398
-0
lines changed

kernels/portable/cpu/op_topk.cpp

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <algorithm>
#include <cmath>
#include <tuple>
#include <utility>

#include <executorch/runtime/kernel/kernel_includes.h>
13+
14+
namespace torch {
15+
namespace executor {
16+
namespace native {
17+
namespace {
18+
19+
/**
 * Validates the arguments to topk.values.
 *
 * Checks that `values` shares the input's dtype, that `indices` is int64,
 * that `dim` is a valid dimension of `in`, and that `k` is within range for
 * that dimension. Returns false (after logging) on the first failed check.
 *
 * Note: `dim` is normalized on a local copy only; the caller re-normalizes
 * its own `dim` after this check succeeds.
 */
bool check_topk_args(
    const Tensor& in,
    int64_t k,
    int64_t dim,
    Tensor& values,
    Tensor& indices) {
  // Output values must match the input dtype; indices are always Long.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, values));
  ET_LOG_AND_RETURN_IF_FALSE(indices.scalar_type() == ScalarType::Long);
  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
  // Normalize a negative dim before the range check below. nonzero_dim /
  // nonempty_size presumably treat 0-d tensors as 1-d of size 1 — TODO
  // confirm against their definitions.
  if (dim < 0) {
    dim += nonzero_dim(in);
  }
  // k == nonempty_size(in, dim) is allowed (select the whole dimension).
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      k >= 0 && k <= nonempty_size(in, dim), "selected index k out of range");
  return true;
}
35+
36+
bool get_topk_target_size(
37+
const Tensor& in,
38+
int64_t k,
39+
int64_t dim,
40+
Tensor::SizesType* target_size,
41+
size_t* target_dim) {
42+
*target_dim = in.dim();
43+
for (size_t i = 0; i < *target_dim; ++i) {
44+
if (i == dim) {
45+
target_size[i] = k;
46+
} else {
47+
target_size[i] = in.size(i);
48+
}
49+
}
50+
return true;
51+
}
52+
53+
/**
 * Computes the top-k (value, index) pairs of `in` along `dim`, writing the
 * results into `values` and `indices`.
 *
 * For each 1-d slice along `dim`, the slice is copied into `queue` (a
 * caller-provided scratch buffer of at least size(dim) elements), partially
 * ordered in place, and the first k entries written back out.
 *
 * Comparator convention (both branches): NaN compares as greater than every
 * non-NaN value, so with largest=true NaNs sort to the front and with
 * largest=false they sort to the back — presumably to match ATen's topk
 * NaN ordering; TODO confirm. For integral CTYPE, std::isnan is always
 * false, so the comparator degenerates to plain </>.
 *
 * @param in      Input tensor.
 * @param k       Number of elements to select (already range-checked).
 * @param dim     Normalized (non-negative) dimension to reduce over.
 * @param largest Select largest elements when true, smallest when false.
 * @param sorted  Whether the k results must be emitted in sorted order.
 * @param values  Output tensor for the selected values (same dtype as in).
 * @param indices Output tensor for the selected indices (int64).
 * @param queue   Scratch buffer of at least in.size(dim) elem_t entries.
 */
template <typename CTYPE, typename elem_t = std::pair<CTYPE, int64_t>>
void perform_topk(
    const Tensor& in,
    int64_t k,
    int64_t dim,
    bool largest,
    bool sorted,
    Tensor& values,
    Tensor& indices,
    elem_t* queue) {
  const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
  CTYPE* values_data = values.mutable_data_ptr<CTYPE>();
  long* indices_data = indices.mutable_data_ptr<long>();

  // 0-d input: the single element is trivially the top-1. Note this writes
  // the output even when k == 0 (the caller only short-circuits k == 0 for
  // dim() > 0 inputs).
  if (in.dim() == 0) {
    values_data[0] = in_data[0];
    indices_data[0] = 0;
    return;
  }

  // Nothing to select; outputs have zero extent along dim.
  if (k == 0) {
    return;
  }

  const size_t outer_size = getLeadingDims(in, dim);

  const size_t dim_size = in.size(dim);
  const size_t dim_stride = in.strides()[dim];

  // Input and output share leading/trailing layout; only the extent along
  // dim differs (dim_size vs k), hence the different outer strides.
  const size_t outer_stride_in = dim_size * dim_stride;
  const size_t outer_stride_out = k * dim_stride;

  // Heuristic: partial_sort is O(n log k) but touches the whole range; when
  // k is a tiny fraction of the slice, nth_element (O(n) average) plus a
  // final sort of just the top-k is cheaper.
  bool use_partial_sort = k * 64 <= dim_size;

  // Loop through all outer dimensions
  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    size_t outer_in = outer_idx * outer_stride_in;
    size_t outer_out = outer_idx * outer_stride_out;
    // Loop through all inner dimensions
    for (size_t inner_idx = 0; inner_idx < dim_stride; ++inner_idx) {
      size_t base_in = outer_in + inner_idx;
      size_t base_out = outer_out + inner_idx;

      // Populate the queue with the values from the input tensor
      for (size_t i = 0; i < dim_size; ++i) {
        size_t in_ix = base_in + i * dim_stride;
        queue[i].first = in_data[in_ix];
        queue[i].second = i;
      }

      // Perform topk on the queue
      if (use_partial_sort) {
        // partial_sort leaves [0, k) fully sorted, so the `sorted` flag is
        // satisfied for free on this path.
        if (largest) {
          std::partial_sort(
              queue,
              queue + k,
              queue + dim_size,
              [](const elem_t& x, const elem_t& y) -> bool {
                return (
                    (std::isnan(x.first) && !std::isnan(y.first)) ||
                    (x.first > y.first));
              });
        } else {
          std::partial_sort(
              queue,
              queue + k,
              queue + dim_size,
              [](const elem_t& x, const elem_t& y) -> bool {
                return (
                    (!std::isnan(x.first) && std::isnan(y.first)) ||
                    (x.first < y.first));
              });
        }
      } else {
        if (largest) {
          // nth_element places the k-th element at position k-1 in its
          // final sorted position, with everything before it no smaller
          // (per the comparator).
          std::nth_element(
              queue,
              queue + k - 1,
              queue + dim_size,
              [](const elem_t& x, const elem_t& y) -> bool {
                return (
                    (std::isnan(x.first) && !std::isnan(y.first)) ||
                    (x.first > y.first));
              });
          if (sorted) {
            // Sorting only [0, k-1) suffices: position k-1 is already in
            // its final sorted place after nth_element.
            std::sort(
                queue,
                queue + k - 1,
                [](const elem_t& x, const elem_t& y) -> bool {
                  return (
                      (std::isnan(x.first) && !std::isnan(y.first)) ||
                      (x.first > y.first));
                });
          }
        } else {
          std::nth_element(
              queue,
              queue + k - 1,
              queue + dim_size,
              [](const elem_t& x, const elem_t& y) -> bool {
                return (
                    (!std::isnan(x.first) && std::isnan(y.first)) ||
                    (x.first < y.first));
              });
          if (sorted) {
            std::sort(
                queue,
                queue + k - 1,
                [](const elem_t& x, const elem_t& y) -> bool {
                  return (
                      (!std::isnan(x.first) && std::isnan(y.first)) ||
                      (x.first < y.first));
                });
          }
        }
      }

      // Write the topk values and indices to the output tensors
      for (size_t i = 0; i < k; ++i) {
        size_t out_ix = base_out + i * dim_stride;

        values_data[out_ix] = queue[i].first;
        indices_data[out_ix] = queue[i].second;
      }
    }
  }
}
180+
181+
// Allocates `size` bytes from the context's temp allocator, returning
// nullptr on failure instead of an error Result.
void* allocate_temp_memory(RuntimeContext& ctx, size_t size) {
  Result<void*> allocation = ctx.allocate_temp(size);
  if (!allocation.ok()) {
    return nullptr;
  }
  return allocation.get();
}
185+
186+
} // namespace
187+
188+
/**
 * topk.values kernel entry point: returns the k largest (or smallest)
 * elements of `in` along `dim`, as a (values, indices) tuple.
 *
 * Resizes both outputs to the input's shape with extent k along `dim`,
 * then dispatches on dtype and runs perform_topk with a temp scratch
 * buffer of size(dim) (value, index) pairs.
 *
 * On any failed check, sets the corresponding error on ctx and returns the
 * outputs as-is.
 */
std::tuple<Tensor&, Tensor&> topk_values(
    RuntimeContext& ctx,
    const Tensor& in,
    int64_t k,
    int64_t dim,
    bool largest,
    bool sorted,
    Tensor& values,
    Tensor& indices) {
  auto out = std::tuple<Tensor&, Tensor&>({values, indices});

  ET_KERNEL_CHECK(
      ctx, check_topk_args(in, k, dim, values, indices), InvalidArgument, out);

  // Normalize negative dim (check_topk_args only normalized a local copy).
  if (dim < 0) {
    dim += nonzero_dim(in);
  }

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  Tensor::SizesType target_size[kTensorDimensionLimit];
  size_t target_dim = 0;
  get_topk_target_size(in, k, dim, target_size, &target_dim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(values, {target_size, target_dim}) == Error::Ok,
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(indices, {target_size, target_dim}) == Error::Ok,
      InvalidArgument,
      out);

  constexpr auto name = "topk.values";

  // Empty input or k == 0 on a non-0-d input: outputs are empty along dim,
  // nothing to compute. (0-d inputs with k == 0 still fall through, and
  // perform_topk writes the single element — presumably to match ATen;
  // TODO confirm.)
  if (in.numel() == 0 || (k == 0 && in.dim() > 0)) {
    return out;
  }

  // Set from inside the dispatch lambda; stays false if either the dtype
  // switch never runs the lambda or the temp allocation fails.
  bool temp_mem_allocated = false;

  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
    using elem_t = std::pair<CTYPE, int64_t>;
    // One (value, index) pair per element along the reduced dimension.
    size_t temp_mem_size = nonempty_size(in, dim) * sizeof(elem_t);

    elem_t* queue = (elem_t*)allocate_temp_memory(ctx, temp_mem_size);
    if (queue == nullptr) {
      return;
    }
    temp_mem_allocated = true;

    perform_topk<CTYPE>(in, k, dim, largest, sorted, values, indices, queue);
  });

  ET_KERNEL_CHECK(ctx, temp_mem_allocated, MemoryAllocationFailed, out);

  return out;
}
248+
249+
} // namespace native
250+
} // namespace executor
251+
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,11 @@
847847
- arg_meta: null
848848
kernel_name: torch::executor::tanh_out
849849

850+
- op: topk.values
851+
kernels:
852+
- arg_meta: null
853+
kernel_name: torch::executor::topk_values
854+
850855
- op: transpose_copy.int_out
851856
kernels:
852857
- arg_meta: null

0 commit comments

Comments
 (0)