Skip to content

Commit 4a731b8

Browse files
authored
Optimize CG-based device insert with erase check for improved performance (#681)
This PR enhances the CG-based device insertion by introducing an additional compile-time `SupportsErase` template flag. When erasure is not required, the new implementation uses this flag to select a more efficient code path in which the CAS operation compares against the empty sentinel directly, rather than against the content loaded from the target slot. This optimization eliminates excessive local memory transactions, resulting in a 10-30% improvement in multimap performance. This PR also updates the multimap insert and count benchmarks to run with the new implementations. It unblocks rapidsai/cudf#18021.
1 parent e783a05 commit 4a731b8

File tree

9 files changed

+55
-24
lines changed

9 files changed

+55
-24
lines changed

benchmarks/static_multimap/count_bench.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -57,12 +57,12 @@ std::enable_if_t<(sizeof(Key) == sizeof(Value)), void> static_multimap_count(
5757

5858
state.add_element_count(num_keys);
5959

60-
cuco::static_multimap<Key, Value> map{
60+
cuco::experimental::static_multimap<Key, Value> map{
6161
size, cuco::empty_key<Key>{-1}, cuco::empty_value<Value>{-1}};
6262
map.insert(pairs.begin(), pairs.end());
6363

6464
state.exec(nvbench::exec_tag::sync, [&](nvbench::launch& launch) {
65-
auto count = map.count(keys.begin(), keys.end(), launch.get_stream());
65+
auto count = map.count(keys.begin(), keys.end(), {launch.get_stream()});
6666
});
6767
}
6868

benchmarks/static_multimap/insert_bench.cu

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2021-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -56,11 +56,18 @@ std::enable_if_t<(sizeof(Key) == sizeof(Value)), void> static_multimap_insert(
5656

5757
state.exec(nvbench::exec_tag::sync | nvbench::exec_tag::timer,
5858
[&](nvbench::launch& launch, auto& timer) {
59-
cuco::static_multimap<Key, Value> map{
60-
size, cuco::empty_key<Key>{-1}, cuco::empty_value<Value>{-1}, launch.get_stream()};
59+
cuco::experimental::static_multimap<Key, Value> map{size,
60+
cuco::empty_key<Key>{-1},
61+
cuco::empty_value<Value>{-1},
62+
{},
63+
{},
64+
{},
65+
{},
66+
{},
67+
{launch.get_stream()}};
6168

6269
timer.start();
63-
map.insert(pairs.begin(), pairs.end(), launch.get_stream());
70+
map.insert(pairs.begin(), pairs.end(), {launch.get_stream()});
6471
timer.stop();
6572
});
6673
}

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ class open_addressing_ref_impl {
426426
*
427427
* @return True if the given element is successfully inserted
428428
*/
429-
template <typename Value>
429+
template <bool SupportsErase, typename Value>
430430
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
431431
Value const& value) noexcept
432432
{
@@ -466,12 +466,20 @@ class open_addressing_ref_impl {
466466
auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE);
467467
if (group_contains_available) {
468468
auto const src_lane = __ffs(group_contains_available) - 1;
469-
auto const status =
470-
(group.thread_rank() == src_lane)
471-
? attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
469+
auto status = insert_result::CONTINUE;
470+
if (group.thread_rank() == src_lane) {
471+
if constexpr (SupportsErase) {
472+
status =
473+
attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
472474
bucket_slots[intra_bucket_index],
473-
val)
474-
: insert_result::CONTINUE;
475+
val);
476+
} else {
477+
status =
478+
attempt_insert((storage_ref_.data() + *probing_iter)->data() + intra_bucket_index,
479+
this->empty_slot_sentinel(),
480+
val);
481+
}
482+
}
475483

476484
switch (group.shfl(status, src_lane)) {
477485
case insert_result::SUCCESS: return true;

include/cuco/detail/operator.inl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -48,7 +48,7 @@ class operator_impl {
4848
* @return `true` if `Operator` is contained in `Operators`, `false` otherwise.
4949
*/
5050
template <typename Operator, typename... Operators>
51-
static constexpr bool has_operator()
51+
__host__ __device__ static constexpr bool has_operator()
5252
{
5353
return ((std::is_same_v<Operators, Operator>) || ...);
5454
}

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,11 @@ class operator_impl<
449449
Value const& value) noexcept
450450
{
451451
auto& ref_ = static_cast<ref_type&>(*this);
452-
return ref_.impl_.insert(group, value);
452+
if (ref_.erased_key_sentinel() != ref_.empty_key_sentinel()) {
453+
return ref_.impl_.insert<true>(group, value);
454+
} else {
455+
return ref_.impl_.insert<false>(group, value);
456+
}
453457
}
454458
};
455459

include/cuco/detail/static_multimap/static_multimap_ref.inl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -455,7 +455,11 @@ class operator_impl<
455455
Value const& value) noexcept
456456
{
457457
auto& ref_ = static_cast<ref_type&>(*this);
458-
return ref_.impl_.insert(group, value);
458+
if (ref_.erased_key_sentinel() != ref_.empty_key_sentinel()) {
459+
return ref_.impl_.insert<true>(group, value);
460+
} else {
461+
return ref_.impl_.insert<false>(group, value);
462+
}
459463
}
460464
};
461465

include/cuco/detail/static_multiset/static_multiset_ref.inl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2024-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -358,7 +358,11 @@ class operator_impl<
358358
Value const& value) noexcept
359359
{
360360
auto& ref_ = static_cast<ref_type&>(*this);
361-
return ref_.impl_.insert(group, value);
361+
if (ref_.erased_key_sentinel() != ref_.empty_key_sentinel()) {
362+
return ref_.impl_.insert<true>(group, value);
363+
} else {
364+
return ref_.impl_.insert<false>(group, value);
365+
}
362366
}
363367
};
364368

include/cuco/detail/static_set/static_set_ref.inl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
2+
* Copyright (c) 2022-2025, NVIDIA CORPORATION.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -391,7 +391,11 @@ class operator_impl<op::insert_tag,
391391
Value const& value) noexcept
392392
{
393393
auto& ref_ = static_cast<ref_type&>(*this);
394-
return ref_.impl_.insert(group, value);
394+
if (ref_.erased_key_sentinel() != ref_.empty_key_sentinel()) {
395+
return ref_.impl_.insert<true>(group, value);
396+
} else {
397+
return ref_.impl_.insert<false>(group, value);
398+
}
395399
}
396400
};
397401

include/cuco/static_multimap.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ template <class Key,
9393
class Extent = cuco::extent<std::size_t>,
9494
cuda::thread_scope Scope = cuda::thread_scope_device,
9595
class KeyEqual = thrust::equal_to<Key>,
96-
class ProbingScheme = cuco::linear_probing<4, // CG size
96+
class ProbingScheme = cuco::double_hashing<8, // CG size
9797
cuco::default_hash_function<Key>>,
9898
class Allocator = cuco::cuda_allocator<cuco::pair<Key, T>>,
99-
class Storage = cuco::storage<1>>
99+
class Storage = cuco::storage<2>>
100100
class static_multimap {
101101
static_assert(sizeof(Key) <= 8, "Container does not support key types larger than 8 bytes.");
102102

0 commit comments

Comments
 (0)