Skip to content

Commit 633c6ec

Browse files
Beanavilcenxuantian
authored andcommitted
Resolve "Expose new top-k API and basic tests."
1 parent 1eeeba6 commit 633c6ec

File tree

6 files changed

+978
-32
lines changed

6 files changed

+978
-32
lines changed
Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
// Copyright (c) 2025 Advanced Micro Devices, Inc. All rights reserved.
2+
//
3+
// Permission is hereby granted, free of charge, to any person obtaining a copy
4+
// of this software and associated documentation files (the "Software"), to deal
5+
// in the Software without restriction, including without limitation the rights
6+
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7+
// copies of the Software, and to permit persons to whom the Software is
8+
// furnished to do so, subject to the following conditions:
9+
//
10+
// The above copyright notice and this permission notice shall be included in
11+
// all copies or substantial portions of the Software.
12+
//
13+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15+
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16+
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17+
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18+
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19+
// THE SOFTWARE.
20+
21+
#ifndef ROCPRIM_DEVICE_DEVICE_TOPK_HPP_
22+
#define ROCPRIM_DEVICE_DEVICE_TOPK_HPP_
23+
24+
#include "../detail/temp_storage.hpp"
25+
26+
#include "device_merge_sort.hpp"
27+
#include "device_radix_sort.hpp"
28+
#include "device_transform.hpp"
29+
30+
#include <iterator>
31+
#include <type_traits>
32+
33+
BEGIN_ROCPRIM_NAMESPACE
34+
35+
/// \addtogroup devicemodule
36+
/// @{
37+
38+
namespace detail
39+
{
40+
41+
template<typename KeysInputIterator, typename BinaryFunction, typename Decomposer>
42+
struct radix_topk_condition_checker
43+
{
44+
using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
45+
46+
static constexpr bool is_custom_decomposer
47+
= !std::is_same<Decomposer, rocprim::identity_decomposer>::value;
48+
static constexpr bool descending
49+
= std::is_same<BinaryFunction, rocprim::greater<key_type>>::value
50+
|| std::is_same<BinaryFunction, rocprim::greater<void>>::value;
51+
static constexpr bool ascending = std::is_same<BinaryFunction, rocprim::less<key_type>>::value
52+
|| std::is_same<BinaryFunction, rocprim::less<void>>::value;
53+
static constexpr bool is_radix_key_fundamental
54+
= rocprim::traits::radix_key_codec::radix_key_fundamental<key_type>::value;
55+
static constexpr bool use_radix
56+
= (is_radix_key_fundamental || is_custom_decomposer) && (descending || ascending);
57+
};
58+
59+
// Primary template for TopKImpl, assumes default topk_impl_algorithm
60+
template<bool UseRadix,
61+
class config,
62+
bool Ordered,
63+
bool Deterministic,
64+
bool Stable,
65+
class KeysInputIterator,
66+
class KeysOutputIterator,
67+
class ValuesInputIterator,
68+
class ValuesOutputIterator,
69+
class SizeIn,
70+
class SizeOut,
71+
class BinaryFunction,
72+
class Decomposer>
73+
struct TopKImpl
74+
{
75+
static ROCPRIM_INLINE
76+
hipError_t algo_impl(void* temporary_storage,
77+
size_t& storage_size,
78+
const KeysInputIterator keys_input,
79+
const KeysOutputIterator keys_output,
80+
const ValuesInputIterator values_input,
81+
const ValuesOutputIterator values_output,
82+
const SizeIn size,
83+
const SizeOut k,
84+
const hipStream_t stream,
85+
const bool debug_synchronous,
86+
const BinaryFunction /*compare_function*/,
87+
const Decomposer decomposer = {})
88+
{
89+
// Default is radix_topk, check we can actually use it
90+
using radix_checker
91+
= radix_topk_condition_checker<KeysInputIterator, BinaryFunction, Decomposer>;
92+
static_assert(UseRadix && radix_checker::use_radix,
93+
"Parameters for TopK implementation RadixTopK are not valid!");
94+
95+
// Check implementation properties
96+
static_assert(!radix_checker::is_custom_decomposer,
97+
"RadixTopK does not support custom keys");
98+
static_assert(Ordered == false, "Radix TopK does not support ordered output");
99+
static_assert(Deterministic == false, "Radix TopK does not support determinism");
100+
static_assert(Stable == false, "Radix TopK does not support stability");
101+
102+
bool ignored;
103+
(void)k;
104+
return detail::radix_sort_impl<config, radix_checker::descending>(
105+
temporary_storage,
106+
storage_size,
107+
keys_input,
108+
nullptr,
109+
keys_output,
110+
values_input,
111+
nullptr,
112+
values_output,
113+
size,
114+
ignored,
115+
decomposer,
116+
0,
117+
sizeof(typename std::iterator_traits<KeysInputIterator>::value_type) * 8,
118+
stream,
119+
false,
120+
debug_synchronous);
121+
}
122+
};
123+
124+
template<bool UseRadix,
125+
class Config,
126+
bool Ordered,
127+
bool Deterministic,
128+
bool Stable,
129+
class KeysInputIterator,
130+
class KeysOutputIterator,
131+
class ValuesInputIterator,
132+
class ValuesOutputIterator,
133+
class SizeIn,
134+
class SizeOut,
135+
class BinaryFunction,
136+
class Decomposer>
137+
ROCPRIM_INLINE
138+
hipError_t topk_impl(void* temporary_storage,
139+
size_t& storage_size,
140+
const KeysInputIterator keys_input,
141+
const KeysOutputIterator keys_output,
142+
const ValuesInputIterator values_input,
143+
const ValuesOutputIterator values_output,
144+
const SizeIn size,
145+
SizeOut k,
146+
const BinaryFunction compare_function = BinaryFunction(),
147+
const Decomposer decomposer = {},
148+
const hipStream_t stream = 0,
149+
const bool debug_synchronous = false)
150+
{
151+
using key_type = typename std::iterator_traits<KeysInputIterator>::value_type;
152+
using value_type = typename std::iterator_traits<ValuesInputIterator>::value_type;
153+
using common_size_t = typename std::common_type<decltype(size), decltype(k)>::type;
154+
static_assert(std::is_integral<common_size_t>::value, "Size and k must be integral types.");
155+
static_assert(
156+
std::is_same<key_type,
157+
typename std::iterator_traits<KeysOutputIterator>::value_type>::value,
158+
"KeysInputIterator and KeysOutputIterator must have the same value_type");
159+
static_assert(
160+
std::is_same<value_type,
161+
typename std::iterator_traits<ValuesOutputIterator>::value_type>::value,
162+
"ValuesInputIterator and ValuesOutputIterator must have the same value_type");
163+
164+
// Limit k to size
165+
if(k < 0)
166+
{
167+
return hipErrorInvalidValue;
168+
}
169+
k = static_cast<SizeOut>(std::min(common_size_t{k}, static_cast<common_size_t>(size)));
170+
171+
if(temporary_storage == nullptr)
172+
{
173+
return detail::TopKImpl<UseRadix,
174+
Config,
175+
Ordered,
176+
Deterministic,
177+
Stable,
178+
KeysInputIterator,
179+
KeysOutputIterator,
180+
ValuesInputIterator,
181+
ValuesOutputIterator,
182+
SizeIn,
183+
SizeOut,
184+
BinaryFunction,
185+
Decomposer>::algo_impl(temporary_storage,
186+
storage_size,
187+
keys_input,
188+
keys_output,
189+
values_input,
190+
values_output,
191+
size,
192+
k,
193+
stream,
194+
debug_synchronous,
195+
compare_function,
196+
decomposer);
197+
}
198+
199+
// Start point for time measurements
200+
std::chrono::steady_clock::time_point start;
201+
if(debug_synchronous)
202+
{
203+
start = std::chrono::steady_clock::now();
204+
}
205+
206+
ROCPRIM_RETURN_ON_ERROR(detail::TopKImpl<UseRadix,
207+
Config,
208+
Ordered,
209+
Deterministic,
210+
Stable,
211+
KeysInputIterator,
212+
KeysOutputIterator,
213+
ValuesInputIterator,
214+
ValuesOutputIterator,
215+
SizeIn,
216+
SizeOut,
217+
BinaryFunction,
218+
Decomposer>::algo_impl(temporary_storage,
219+
storage_size,
220+
keys_input,
221+
keys_input,
222+
values_input,
223+
values_input,
224+
size,
225+
k,
226+
stream,
227+
debug_synchronous,
228+
compare_function,
229+
decomposer));
230+
ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("TopKImpl::algo_impl", size, start);
231+
232+
ROCPRIM_RETURN_ON_ERROR(
233+
transform(keys_input, keys_output, k, ::rocprim::identity<>(), stream, debug_synchronous));
234+
static constexpr bool with_values = !std::is_same<value_type, rocprim::empty_type>::value;
235+
if constexpr(with_values)
236+
{
237+
ROCPRIM_RETURN_ON_ERROR(transform(values_input,
238+
values_output,
239+
k,
240+
::rocprim::identity<>(),
241+
stream,
242+
debug_synchronous));
243+
}
244+
return hipSuccess;
245+
}
246+
247+
} // namespace detail
248+
249+
/// \brief Find the largest/smallest K elements from an input array of keys.
250+
///
251+
/// The K elements are returned within the K first positions of the output array and in a non specific order.
252+
template<class Config = default_config,
253+
bool Descending = false,
254+
bool Ordered = false,
255+
bool Deterministic = false,
256+
bool Stable = false,
257+
class Decomposer = ::rocprim::identity_decomposer,
258+
class KeysInputIterator,
259+
class KeysOutputIterator,
260+
class SizeIn,
261+
class SizeOut>
262+
ROCPRIM_INLINE
263+
hipError_t topk(void* temporary_storage,
264+
size_t& storage_size,
265+
const KeysInputIterator keys_input,
266+
const KeysOutputIterator keys_output,
267+
const SizeIn size,
268+
const SizeOut k,
269+
Decomposer decomposer = {},
270+
const hipStream_t stream = 0,
271+
const bool debug_synchronous = false)
272+
{
273+
using compare_function = std::conditional_t<
274+
Descending,
275+
rocprim::greater<typename std::iterator_traits<KeysInputIterator>::value_type>,
276+
rocprim::less<typename std::iterator_traits<KeysInputIterator>::value_type>>;
277+
return detail::topk_impl<true, Config, Ordered, Deterministic, Stable>(
278+
temporary_storage,
279+
storage_size,
280+
keys_input,
281+
keys_output,
282+
static_cast<empty_type*>(nullptr),
283+
static_cast<empty_type*>(nullptr),
284+
size,
285+
k,
286+
compare_function(),
287+
decomposer,
288+
stream,
289+
debug_synchronous);
290+
}
291+
292+
/// \brief Find the largest/smallest K elements from an input array of values based on their correspondent keys.
293+
///
294+
/// The K pairs (key, value) are returned within the K first positions of the output keys and values arrays,
295+
/// and in a non specific order.
296+
template<class Config = default_config,
297+
bool Descending = false,
298+
bool Ordered = false,
299+
bool Deterministic = false,
300+
bool Stable = false,
301+
class Decomposer = rocprim::identity_decomposer,
302+
class KeysInputIterator,
303+
class KeysOutputIterator,
304+
class ValuesInputIterator,
305+
class ValuesOutputIterator,
306+
class SizeIn,
307+
class SizeOut>
308+
ROCPRIM_INLINE
309+
hipError_t topk_pairs(void* temporary_storage,
310+
size_t& storage_size,
311+
const KeysInputIterator keys_input,
312+
const KeysOutputIterator keys_output,
313+
const ValuesInputIterator values_input,
314+
const ValuesOutputIterator values_output,
315+
const SizeIn size,
316+
const SizeOut k,
317+
const Decomposer decomposer = {},
318+
const hipStream_t stream = 0,
319+
const bool debug_synchronous = false)
320+
{
321+
using compare_function = std::conditional_t<
322+
Descending,
323+
rocprim::greater<typename std::iterator_traits<KeysInputIterator>::value_type>,
324+
rocprim::less<typename std::iterator_traits<KeysInputIterator>::value_type>>;
325+
return detail::topk_impl<true, Config, Ordered, Deterministic, Stable>(temporary_storage,
326+
storage_size,
327+
keys_input,
328+
keys_output,
329+
values_input,
330+
values_output,
331+
size,
332+
k,
333+
compare_function(),
334+
decomposer,
335+
stream,
336+
debug_synchronous);
337+
}
338+
339+
END_ROCPRIM_NAMESPACE
340+
341+
/// @}
342+
// end of group devicemodule
343+
344+
#endif // ROCPRIM_DEVICE_DEVICE_TOPK_HPP_

projects/rocprim/rocprim/include/rocprim/rocprim.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
#include "device/device_segmented_reduce.hpp"
8383
#include "device/device_segmented_scan.hpp"
8484
#include "device/device_select.hpp"
85+
#include "device/device_topk.hpp"
8586
#include "device/device_transform.hpp"
8687

8788
/// \brief The top level rocPRIM namespace.

projects/rocprim/test/rocprim/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ add_rocprim_test("rocprim.device_histogram" test_device_histogram.cpp)
299299
add_rocprim_test("rocprim.device_merge" test_device_merge.cpp)
300300
add_rocprim_test("rocprim.device_merge_inplace" test_device_merge_inplace.cpp)
301301
add_rocprim_test("rocprim.device_merge_sort" test_device_merge_sort.cpp)
302-
add_rocprim_test("rocprim.nth_element" test_device_nth_element.cpp)
302+
add_rocprim_test("rocprim.device_nth_element" test_device_nth_element.cpp)
303303
add_rocprim_test("rocprim.device_partial_sort" test_device_partial_sort.cpp)
304304
add_rocprim_test("rocprim.device_partition" test_device_partition.cpp)
305305
add_rocprim_test_parallel("rocprim.device_radix_sort" test_device_radix_sort.cpp.in)
@@ -314,6 +314,7 @@ add_rocprim_test("rocprim.device_search_n" test_device_search_n.cpp)
314314
add_rocprim_test("rocprim.device_segmented_reduce" test_device_segmented_reduce.cpp)
315315
add_rocprim_test("rocprim.device_segmented_scan" test_device_segmented_scan.cpp)
316316
add_rocprim_test("rocprim.device_select" test_device_select.cpp)
317+
add_rocprim_test("rocprim.device_topk_radix" test_device_topk_radix.cpp)
317318
add_rocprim_test("rocprim.device_transform" test_device_transform.cpp)
318319
add_rocprim_test("rocprim.discard_iterator" test_discard_iterator.cpp)
319320
add_rocprim_test("rocprim.lookback_reproducibility" test_lookback_reproducibility.cpp)

0 commit comments

Comments
 (0)