Skip to content

[rocPRIM] Add device level parallel algorithm TopK and its tests and benchmarks#3646

Open
cenxuantian wants to merge 12 commits intoROCm:developfrom
StreamHPC:users/cenxuantian/topk-upstreaming
Open

[rocPRIM] Add device level parallel algorithm TopK and its tests and benchmarks#3646
cenxuantian wants to merge 12 commits intoROCm:developfrom
StreamHPC:users/cenxuantian/topk-upstreaming

Conversation

@cenxuantian
Copy link
Contributor

@cenxuantian cenxuantian commented Jan 6, 2026

Motivation

CUB added topk algorithm, so we also add this algorithm to our library.

Technical Details

In this PR, several items are addressed:

  • The device-level parallel algorithm APIs topk and topk_pair were added, following the principles below:
    • Determinism: not guaranteed
    • Stability:
      • Stable: stable (the stable version directly calls device_radix_sort)
      • Unstable: an unstable parallel algorithm rocprim::detail::device_topk_air_topk is implemented
    • extend_k_by_ties: no
    • Ordered: unordered
    • In-place: out-of-place
  • Tests for both stable and unstable topk were added
  • Benchmarks for device_topk_air_topk were added (considering small/large k and natural/radix-adversarial distribution)
  • Configuration and tuning mechanisms were added to device_topk_air_topk
    (The public APIs rocprim::topk and rocprim::topk_pair do not require tuning since they call radix_sort directly, and their performance is highly dependent on radix_sort.)
  • Documentation for the public topk API and the internal topk_air_topk implementation was added
    • To test the documentation, some Doxygen failures and warnings were also fixed
    • Since this algorithm performs a similar operation to nth_element, additional documentation was added to clarify the differences between them
  • A tested example was added

Technical details of device_topk_air_topk

This internal API was implemented based on this paper, after researching several similar approaches. We adopted the core ideas of Adaptive, Iteration-fused, and Radix-based processing from the paper, and implemented the algorithm within rocPRIM’s configuration and dispatching system.

Some optimizations were added to improve performance, including compile-time iteration count calculation, batched thread stores, and automatic decay of the input size type. Moreover, there is no explicit CPU–GPU data communication between iterations, which allows this algorithm to support hipGraph.

The "A-I-R" acronym stands for:

  • Adaptive: The algorithm maintains a temporary buffer in shared memory to store the input when necessary. This is designed to handle radix-adversarial distributions. If the input is evenly distributed, this buffer is not used. The algorithm decides whether to enable this feature by analyzing the item distribution from the first histogram generated during the initial kernel iteration.
  • Iteration-fused: In each iteration, computing the histogram for the current iteration and filtering the results from the previous iteration can be performed simultaneously.
  • Radix-based: The algorithm performs comparisons using radix operations, but in a different manner than radix_sort. It extracts digits from the most significant to the least significant digit, and distinguishes ascending and descending order by flipping digits (referred to as TwiddleIn in hipcub).

The overall logic can be illustrated by the flowchart below (excluding Adaptive and Iteration-fused features, as including them would make the diagram significantly more complex).
image

Test Plan

Tests for general topk and Airtopk are added

Test Result

Submission Checklist

@cenxuantian cenxuantian requested a review from a team as a code owner January 6, 2026 09:52
@assistant-librarian assistant-librarian bot added the external contribution Code contribution from users community.. label Jan 6, 2026
@Naraenda Naraenda added organization: streamhpc contributors from streamhpc labels Jan 6, 2026
@cenxuantian cenxuantian changed the title Draft: Resolve "Expose new top-k API and basic tests." Draft: Add device level parallel algorithm TopK and its tests and benchmarks Jan 6, 2026
@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch from 633c6ec to fae46f2 Compare January 8, 2026 09:50
@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch from fae46f2 to 9c213e9 Compare January 20, 2026 12:32
@cenxuantian cenxuantian requested a review from a team as a code owner January 20, 2026 12:32
@umfranzw umfranzw changed the title Draft: Add device level parallel algorithm TopK and its tests and benchmarks [rocPRIM] Add device level parallel algorithm TopK and its tests and benchmarks Jan 21, 2026
@umfranzw
Copy link
Contributor

Hi @cenxuantian, in the CI checks, we're getting a segmentation fault on gfx950 and gfx942 in rocprim.block_exchange. I'm able to reproduce the problem with this PR, but not without it (using the latest from develop). Maybe a rebase will fix this up.

@cenxuantian
Copy link
Contributor Author

Hi @cenxuantian, in the CI checks, we're getting a segmentation fault on gfx950 and gfx942 in rocprim.block_exchange. I'm able to reproduce the problem with this PR, but not without it (using the latest from develop). Maybe a rebase will fix this up.

Thanks, I will check that!

@cenxuantian cenxuantian changed the title [rocPRIM] Add device level parallel algorithm TopK and its tests and benchmarks [Draft][rocPRIM] Add device level parallel algorithm TopK and its tests and benchmarks Jan 21, 2026
@cenxuantian
Copy link
Contributor Author

Hi @umfranzw
BTW, this MR is not yet ready for review, since we still waiting for 2 commits to be reviewed from our side. But it's totally ok to start reviewing it now. Just for your information that, I will push 1 or 2 commits in the next few days.

@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch 4 times, most recently from b3d7b06 to c623e9a Compare January 22, 2026 11:42
@cenxuantian cenxuantian changed the title [Draft][rocPRIM] Add device level parallel algorithm TopK and its tests and benchmarks [rocPRIM] Add device level parallel algorithm TopK and its tests and benchmarks Jan 22, 2026
@cenxuantian
Copy link
Contributor Author

Hi @umfranzw ,

It's now ready for review. All necessary features related to the topk are included in this PR. The next step for topk is to:

  • Implement segemented_topk (CCCL3.1 / rocm8.1 or rocm 8.2)
  • Tune this algorithm (probably 8.1)
  • Explore more opportunities of optimization

I will also link this PR in jira. If there are any questions about the implementation, please let me know. Thanks!

Best regards,
Cenxuan

Copy link
Contributor

@umfranzw umfranzw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks good. I just have one quick question.

@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch from 63bb991 to 54e520e Compare January 26, 2026 16:17
@umfranzw
Copy link
Contributor

Hi @cenxuantian, I took a look into the Windows CI build failure. There's a known issue on Windows where GTest can't print 128-bit values. Because of this, we can't call ASSERT_EQ of 128-bit values directly - instead we have to call test_utils::assert_eq, which contain a workaround for this problem. That workaround is implemented in the protected_assert_eq functions in test/rocprim/test_utils_assertions.hpp (which are called by test_utils::assert_eq).

The problem is that some of the topK test pass test_utils::assert_eq arguments of type std::pair, and there is no overload of protected_assert_eq that handles that. I found that if I add this overload, it resolves the problem:

template <class T, class U, bool UseGTestAssert = is_printable<T> && is_printable<U>>
void inline protected_assert_eq(std::pair<T, U> val, std::pair<T, U> expected, size_t index)
{
    if constexpr (UseGTestAssert)
    {
        ASSERT_EQ(val, expected) << "where index = " << index;
    }
    else
    {
        const bool result = (val == expected);
        ASSERT_TRUE(result) << "where index = " << index;
    }
}

@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch 2 times, most recently from 914eff5 to e84b019 Compare January 27, 2026 12:40
@cenxuantian
Copy link
Contributor Author

Hi @cenxuantian, I took a look into the Windows CI build failure. There's a known issue on Windows where GTest can't print 128-bit values. Because of this, we can't call ASSERT_EQ of 128-bit values directly - instead we have to call test_utils::assert_eq, which contain a workaround for this problem. That workaround is implemented in the protected_assert_eq functions in test/rocprim/test_utils_assertions.hpp (which are called by test_utils::assert_eq).

The problem is that some of the topK test pass test_utils::assert_eq arguments of type std::pair, and there is no overload of protected_assert_eq that handles that. I found that if I add this overload, it resolves the problem:

template <class T, class U, bool UseGTestAssert = is_printable<T> && is_printable<U>>
void inline protected_assert_eq(std::pair<T, U> val, std::pair<T, U> expected, size_t index)
{
    if constexpr (UseGTestAssert)
    {
        ASSERT_EQ(val, expected) << "where index = " << index;
    }
    else
    {
        const bool result = (val == expected);
        ASSERT_TRUE(result) << "where index = " << index;
    }
}

Hi @umfranzw, Thanks, I changed it to use test_utils::assert_eq.

@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch 2 times, most recently from 9a36415 to aafe281 Compare February 2, 2026 13:53
Copy link
Contributor

@umfranzw umfranzw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good to me, as soon as it passes CI.

@umfranzw
Copy link
Contributor

Hi @cenxuantian, when you have a chance, would you mind taking a look at the merge conflicts? Once they're sorted, I'll rerun CI and hopefully we can finally merge this.

@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch from aafe281 to 3a796b0 Compare February 18, 2026 14:11
@cenxuantian
Copy link
Contributor Author

Hi @cenxuantian, when you have a chance, would you mind taking a look at the merge conflicts? Once they're sorted, I'll rerun CI and hopefully we can finally merge this.

Hi @umfranzw, the merge conflicts are resolved

@cenxuantian
Copy link
Contributor Author

Hi @umfranzw,

I pushed 1 commit with a few lines of changes as a temporary workaround for the known issue of __syncthreads_or, which refuses to build with the -O0 flag. This workaround doesn't affect the performance except when the user builds it with the -O0 flag.

We also finished tuning for topk_air. I can either push those changes here in this PR, or create a separate PR after this PR is merged. I am wondering which way you prefer.

Kind regards,
Cenxuan

@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch from 82bc939 to 9bd2d6f Compare February 19, 2026 10:50
@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch 3 times, most recently from 53ca367 to ea76745 Compare March 11, 2026 10:28
@cenxuantian
Copy link
Contributor Author

Hi @umfranzw,

I rebased this PR, and it's ready to for review again. Because "primbench" was merged, I had to update the benchmarks for topk to use the new primbench instead of google benchmark. So I added 1 more commit, which updates the benchmarks and also adds tuned configs for this algorithm. If there are any CI failure or anything else, please let me know.

Best regards

@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch 2 times, most recently from 448b6a3 to 122bbc1 Compare March 13, 2026 13:21
@umfranzw
Copy link
Contributor

Hi @umfranzw,

I rebased this PR, and it's ready to for review again. Because "primbench" was merged, I had to update the benchmarks for topk to use the new primbench instead of google benchmark. So I added 1 more commit, which updates the benchmarks and also adds tuned configs for this algorithm. If there are any CI failure or anything else, please let me know.

Best regards

Hi @cenxuantian, really sorry for the delay on this. The changes look good to me. We'll just need to figure out what's blocking the Windows build.

@cenxuantian cenxuantian force-pushed the users/cenxuantian/topk-upstreaming branch from 122bbc1 to 8aa38a7 Compare March 18, 2026 17:18

for(auto size : sizes)
{
hipStream_t stream = 0; // default
Copy link
Contributor

@umfranzw umfranzw Mar 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @cenxuantian, I finally tracked down the root cause of the Windows CI error we're getting. The error message is:

[----------] 2 tests from RocprimDeviceTopkTests
[ RUN      ] RocprimDeviceTopkTests.TopkLargeSizesStable
HIP error: unspecified launch failure file: C:/home/runner/_work/rocm-libraries/rocm-libraries/projects/rocprim/test/rocprim/test_device_topk.cpp line: 706

      Start 49: rocprim.device_topk

It looks like this happens when gfx1151 runs out of memory on Windows.
Can we add a memory check inside the "size" for loop to continue in the case where we find we don't have enough memory?

I've done a few experiments, and hipMemGetInfo seems to be the best way to go about this because it returns the amount of currently available free memory. On APUs like gfx1151, it seems that the amount of available memory can sometimes be quite low because the memory system is shared with the host (which can use the memory for other things).

// Get the amount of currently available free memory
size_t free_mem;
size_t total_mem;
HIP_CHECK(hipMemGetInfo(&free_mem, &total_mem));

// This test needs to allocate 2 buffers of size: sizeof(key_type) * size, plus temporary storage.
// Since we have to allocate before computing the temp storage bytes (because the first call to topk 
// requires the input and output pointers), maybe we can estimate the temp storage requirements like this?
if (2 * sizeof(key_type) * size > static_cast<size_t>(0.9 * free_mem))
{
  std::cout << "Skipping test size - not enough available memory." << std::endl;
  continue;
}

I haven't actually tried this yet, so I'm not sure the factor of 0.9 will work. I can give it a shot tomorrow on our local gfx1151 test system and get back to you. Or if you have any other suggestions, feel free to let me know.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

documentation external contribution Code contribution from users community.. organization: streamhpc contributors from streamhpc project: rocprim

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants