[rocPRIM] Add device level parallel algorithm TopK and its tests and benchmarks#3646
cenxuantian wants to merge 12 commits into ROCm:develop
Hi @cenxuantian, in the CI checks, we're getting a segmentation fault on gfx950 and gfx942 in rocprim.block_exchange. I'm able to reproduce the problem with this PR, but not without it (using the latest from develop). Maybe a rebase will fix this up. |
Thanks, I will check that! |
Hi @umfranzw |
Hi @umfranzw , It's now ready for review. All necessary features related to the topk are included in this PR. The next step for topk is to:
I will also link this PR in jira. If there are any questions about the implementation, please let me know. Thanks! Best regards, |
umfranzw
left a comment
I think this looks good. I just have one quick question.
Hi @cenxuantian, I took a look into the Windows CI build failure. There's a known issue on Windows where GTest can't print 128-bit values. Because of this, we can't call [...] The problem is that some of the topK test pass [...]
// Falls back to ASSERT_TRUE when T or U cannot be printed by GTest
// (for example, 128-bit values on Windows).
template <class T, class U, bool UseGTestAssert = is_printable<T> && is_printable<U>>
inline void protected_assert_eq(std::pair<T, U> val, std::pair<T, U> expected, size_t index)
{
    if constexpr (UseGTestAssert)
    {
        // Printable types: GTest can report both values on failure.
        ASSERT_EQ(val, expected) << "where index = " << index;
    }
    else
    {
        // Non-printable types: compare first so GTest only has to print a bool.
        const bool result = (val == expected);
        ASSERT_TRUE(result) << "where index = " << index;
    }
}
Hi @umfranzw, Thanks, I changed it to use |
umfranzw
left a comment
This looks good to me, as soon as it passes CI.
Hi @cenxuantian, when you have a chance, would you mind taking a look at the merge conflicts? Once they're sorted, I'll rerun CI and hopefully we can finally merge this. |
Hi @umfranzw, the merge conflicts are resolved |
Hi @umfranzw, I pushed 1 commit with a few lines of changes as a temporary workaround for the known issue of We also finished tuning for Kind regards, |
Hi @umfranzw, I rebased this PR, and it's ready for review again. Because "primbench" was merged, I had to update the topk benchmarks to use the new primbench instead of Google Benchmark. So I added one more commit, which updates the benchmarks and also adds tuned configs for this algorithm. If there are any CI failures or anything else, please let me know. Best regards |
Hi @cenxuantian, really sorry for the delay on this. The changes look good to me. We'll just need to figure out what's blocking the Windows build. |
…ssue of the dpp shifting
for(auto size : sizes)
{
    hipStream_t stream = 0; // default
Hi @cenxuantian, I finally tracked down the root cause of the Windows CI error we're getting. The error message is:
[----------] 2 tests from RocprimDeviceTopkTests
[ RUN ] RocprimDeviceTopkTests.TopkLargeSizesStable
HIP error: unspecified launch failure file: C:/home/runner/_work/rocm-libraries/rocm-libraries/projects/rocprim/test/rocprim/test_device_topk.cpp line: 706
Start 49: rocprim.device_topk
It looks like this happens when gfx1151 runs out of memory on Windows.
Can we add a memory check inside the "size" for loop to continue in the case where we find we don't have enough memory?
I've done a few experiments, and hipMemGetInfo seems to be the best way to go about this because it returns the amount of currently available free memory. On APUs like gfx1151, it seems that the amount of available memory can sometimes be quite low because the memory system is shared with the host (which can use the memory for other things).
// Get the amount of currently available free memory
size_t free_mem;
size_t total_mem;
HIP_CHECK(hipMemGetInfo(&free_mem, &total_mem));
// This test needs to allocate 2 buffers of size: sizeof(key_type) * size, plus temporary storage.
// Since we have to allocate before computing the temp storage bytes (because the first call to topk
// requires the input and output pointers), maybe we can estimate the temp storage requirements like this?
if (2 * sizeof(key_type) * size > static_cast<size_t>(0.9 * free_mem))
{
std::cout << "Skipping test size - not enough available memory." << std::endl;
continue;
}
I haven't actually tried this yet, so I'm not sure the factor of 0.9 will work. I can give it a shot tomorrow on our local gfx1151 test system and get back to you. Or if you have any other suggestions, feel free to let me know.
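One way to make the proposed check easier to tune is to move the arithmetic into a small helper that does not depend on HIP. The sketch below is only an illustration: the name `fits_in_free_memory`, its parameters, and the default fraction of 0.9 are assumptions taken from the comment above, not existing rocPRIM or HIP code, and the fraction is untested on gfx1151.

```cpp
#include <cstddef>
#include <limits>

// Hypothetical helper (not part of rocPRIM or HIP): returns true when
// `buffer_count` buffers of `count` elements, each `elem_size` bytes, fit
// within `fraction` of `free_bytes`. The unused headroom stands in for the
// temporary storage, whose exact size is not known at this point.
// `fraction` is expected to lie in [0, 1).
inline bool fits_in_free_memory(std::size_t count,
                                std::size_t elem_size,
                                std::size_t buffer_count,
                                std::size_t free_bytes,
                                double      fraction = 0.9)
{
    if(elem_size == 0 || buffer_count == 0)
    {
        return true; // nothing to allocate
    }

    // Guard against overflow in count * elem_size * buffer_count.
    const std::size_t max_size = std::numeric_limits<std::size_t>::max();
    if(count > max_size / elem_size / buffer_count)
    {
        return false;
    }

    const std::size_t required = count * elem_size * buffer_count;
    const std::size_t budget
        = static_cast<std::size_t>(fraction * static_cast<double>(free_bytes));
    return required <= budget;
}
```

Inside the size loop, the value that `hipMemGetInfo` writes to `free_mem` would be passed as `free_bytes`, with `sizeof(key_type)` as `elem_size` and 2 as `buffer_count`, and the loop would `continue` when the helper returns false.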
Motivation
CUB added a topk algorithm, so we are adding the same algorithm to rocPRIM.
Technical Details
In this PR, several items are addressed:
- topk and topk_pair were added, following the principles below:
  - [...] device_radix_sort)
- [...] rocprim::detail::device_topk_air_topk is implemented
- [...] topk were added
- [...] device_topk_air_topk were added (considering small/large k and natural/radix-adversarial distribution)
- [...] device_topk_air_topk (The public APIs rocprim::topk and rocprim::topk_pair do not require tuning since they call radix_sort directly, and their performance is highly dependent on radix_sort.)
- [...] topk API and the internal topk_air_topk implementation was added
- [...] nth_element, additional documentation was added to clarify the differences between them

Technical details of device_topk_air_topk

This internal API was implemented based on this paper, after researching several similar approaches. We adopted the core ideas of Adaptive, Iteration-fused, and Radix-based processing from the paper, and implemented the algorithm within rocPRIM’s configuration and dispatching system.

Some optimizations were added to improve performance, including compile-time iteration count calculation, batched thread stores, and automatic decay of the input size type. Moreover, there is no explicit CPU–GPU data communication between iterations, which allows this algorithm to support hipGraph.

The "A-I-R" acronym stands for:
- [...] radix_sort. It extracts digits from the most significant to the least significant digit, and distinguishes ascending and descending order by flipping digits (referred to as TwiddleIn in hipcub).

The overall logic can be illustrated by the flowchart below (excluding Adaptive and Iteration-fused features, as including them would make the diagram significantly more complex).

Test Plan
Tests for both the general topk and Airtopk were added.
Test Result
Submission Checklist