Skip to content

Commit 681cf95

Browse files
authored
Fix hash table full occupancy bug (#647)
The existing hash table implementation relies on empty slots to terminate the probing sequence, which causes hangs when inserting into or querying a fully occupied hash table. This PR resolves the issue by recording the initial slot index for each probed key and terminating the probing sequence once it has looped through all slots and returned to that starting index. Benchmark tests confirm that this change has no performance impact on hash tables that are not fully occupied.
1 parent a4fb985 commit 681cf95

File tree

2 files changed

+67
-27
lines changed

2 files changed

+67
-27
lines changed

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,10 @@ class open_addressing_ref_impl {
378378
{
379379
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
380380

381-
auto const val = this->heterogeneous_value(value);
382-
auto const key = this->extract_key(val);
383-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
381+
auto const val = this->heterogeneous_value(value);
382+
auto const key = this->extract_key(val);
383+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
384+
auto const init_idx = *probing_iter;
384385

385386
while (true) {
386387
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -411,6 +412,7 @@ class open_addressing_ref_impl {
411412
}
412413
}
413414
++probing_iter;
415+
if (*probing_iter == init_idx) { return false; }
414416
}
415417
}
416418

@@ -428,9 +430,10 @@ class open_addressing_ref_impl {
428430
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
429431
Value const& value) noexcept
430432
{
431-
auto const val = this->heterogeneous_value(value);
432-
auto const key = this->extract_key(val);
433-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
433+
auto const val = this->heterogeneous_value(value);
434+
auto const key = this->extract_key(val);
435+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
436+
auto const init_idx = *probing_iter;
434437

435438
while (true) {
436439
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -483,6 +486,7 @@ class open_addressing_ref_impl {
483486
}
484487
} else {
485488
++probing_iter;
489+
if (*probing_iter == init_idx) { return false; }
486490
}
487491
}
488492
}
@@ -513,9 +517,10 @@ class open_addressing_ref_impl {
513517
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
514518
#endif
515519

516-
auto const val = this->heterogeneous_value(value);
517-
auto const key = this->extract_key(val);
518-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
520+
auto const val = this->heterogeneous_value(value);
521+
auto const key = this->extract_key(val);
522+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
523+
auto const init_idx = *probing_iter;
519524

520525
while (true) {
521526
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -554,6 +559,7 @@ class open_addressing_ref_impl {
554559
}
555560
}
556561
++probing_iter;
562+
if (*probing_iter == init_idx) { return {this->end(), false}; }
557563
};
558564
}
559565

@@ -584,9 +590,10 @@ class open_addressing_ref_impl {
584590
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
585591
#endif
586592

587-
auto const val = this->heterogeneous_value(value);
588-
auto const key = this->extract_key(val);
589-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
593+
auto const val = this->heterogeneous_value(value);
594+
auto const key = this->extract_key(val);
595+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
596+
auto const init_idx = *probing_iter;
590597

591598
while (true) {
592599
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -653,6 +660,7 @@ class open_addressing_ref_impl {
653660
}
654661
} else {
655662
++probing_iter;
663+
if (*probing_iter == init_idx) { return {this->end(), false}; }
656664
}
657665
}
658666
}
@@ -671,7 +679,8 @@ class open_addressing_ref_impl {
671679
{
672680
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
673681

674-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
682+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
683+
auto const init_idx = *probing_iter;
675684

676685
while (true) {
677686
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -696,6 +705,7 @@ class open_addressing_ref_impl {
696705
}
697706
}
698707
++probing_iter;
708+
if (*probing_iter == init_idx) { return false; }
699709
}
700710
}
701711

@@ -713,7 +723,8 @@ class open_addressing_ref_impl {
713723
__device__ bool erase(cooperative_groups::thread_block_tile<cg_size> const& group,
714724
ProbeKey const& key) noexcept
715725
{
716-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
726+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
727+
auto const init_idx = *probing_iter;
717728

718729
while (true) {
719730
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -750,6 +761,7 @@ class open_addressing_ref_impl {
750761
if (group.any(state == detail::equal_result::EMPTY)) { return false; }
751762

752763
++probing_iter;
764+
if (*probing_iter == init_idx) { return false; }
753765
}
754766
}
755767

@@ -769,7 +781,8 @@ class open_addressing_ref_impl {
769781
[[nodiscard]] __device__ bool contains(ProbeKey const& key) const noexcept
770782
{
771783
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
772-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
784+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
785+
auto const init_idx = *probing_iter;
773786

774787
while (true) {
775788
// TODO atomic_ref::load if insert operator is present
@@ -783,6 +796,7 @@ class open_addressing_ref_impl {
783796
}
784797
}
785798
++probing_iter;
799+
if (*probing_iter == init_idx) { return false; }
786800
}
787801
}
788802

@@ -803,7 +817,8 @@ class open_addressing_ref_impl {
803817
[[nodiscard]] __device__ bool contains(
804818
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
805819
{
806-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
820+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
821+
auto const init_idx = *probing_iter;
807822

808823
while (true) {
809824
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -821,6 +836,7 @@ class open_addressing_ref_impl {
821836
if (group.any(state == detail::equal_result::EMPTY)) { return false; }
822837

823838
++probing_iter;
839+
if (*probing_iter == init_idx) { return false; }
824840
}
825841
}
826842

@@ -840,7 +856,8 @@ class open_addressing_ref_impl {
840856
[[nodiscard]] __device__ const_iterator find(ProbeKey const& key) const noexcept
841857
{
842858
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
843-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
859+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
860+
auto const init_idx = *probing_iter;
844861

845862
while (true) {
846863
// TODO atomic_ref::load if insert operator is present
@@ -859,6 +876,7 @@ class open_addressing_ref_impl {
859876
}
860877
}
861878
++probing_iter;
879+
if (*probing_iter == init_idx) { return this->end(); }
862880
}
863881
}
864882

@@ -879,7 +897,8 @@ class open_addressing_ref_impl {
879897
[[nodiscard]] __device__ const_iterator find(
880898
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
881899
{
882-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
900+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
901+
auto const init_idx = *probing_iter;
883902

884903
while (true) {
885904
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -908,6 +927,7 @@ class open_addressing_ref_impl {
908927
if (group.any(state == detail::equal_result::EMPTY)) { return this->end(); }
909928

910929
++probing_iter;
930+
if (*probing_iter == init_idx) { return this->end(); }
911931
}
912932
}
913933

@@ -926,8 +946,9 @@ class open_addressing_ref_impl {
926946
if constexpr (not allows_duplicates) {
927947
return static_cast<size_type>(this->contains(key));
928948
} else {
929-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
930-
size_type count = 0;
949+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
950+
auto const init_idx = *probing_iter;
951+
size_type count = 0;
931952

932953
while (true) {
933954
// TODO atomic_ref::load if insert operator is present
@@ -942,6 +963,7 @@ class open_addressing_ref_impl {
942963
}
943964
}
944965
++probing_iter;
966+
if (*probing_iter == init_idx) { return count; }
945967
}
946968
}
947969
}
@@ -960,8 +982,9 @@ class open_addressing_ref_impl {
960982
[[nodiscard]] __device__ size_type count(
961983
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
962984
{
963-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
964-
size_type count = 0;
985+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
986+
auto const init_idx = *probing_iter;
987+
size_type count = 0;
965988

966989
while (true) {
967990
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -978,6 +1001,7 @@ class open_addressing_ref_impl {
9781001

9791002
if (group.any(state == detail::equal_result::EMPTY)) { return count; }
9801003
++probing_iter;
1004+
if (*probing_iter == init_idx) { return count; }
9811005
}
9821006
}
9831007

@@ -1177,6 +1201,7 @@ class open_addressing_ref_impl {
11771201
auto const& probe_key = *(input_probe + idx);
11781202
auto probing_iter =
11791203
this->probing_scheme_(probing_tile, probe_key, this->storage_ref_.bucket_extent());
1204+
auto const init_idx = *probing_iter;
11801205

11811206
bool running = true;
11821207
[[maybe_unused]] bool found_match = false;
@@ -1277,6 +1302,7 @@ class open_addressing_ref_impl {
12771302

12781303
// onto the next probing bucket
12791304
++probing_iter;
1305+
if (*probing_iter == init_idx) { running = false; }
12801306
} // while running
12811307
} // if active_flag
12821308

@@ -1305,7 +1331,8 @@ class open_addressing_ref_impl {
13051331
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
13061332
{
13071333
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
1308-
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent());
1334+
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent());
1335+
auto const init_idx = *probing_iter;
13091336

13101337
while (true) {
13111338
// TODO atomic_ref::load if insert operator is present
@@ -1325,6 +1352,7 @@ class open_addressing_ref_impl {
13251352
}
13261353
}
13271354
++probing_iter;
1355+
if (*probing_iter == init_idx) { return; }
13281356
}
13291357
}
13301358

@@ -1352,8 +1380,9 @@ class open_addressing_ref_impl {
13521380
ProbeKey const& key,
13531381
CallbackOp&& callback_op) const noexcept
13541382
{
1355-
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
1356-
bool empty = false;
1383+
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
1384+
auto const init_idx = *probing_iter;
1385+
bool empty = false;
13571386

13581387
while (true) {
13591388
// TODO atomic_ref::load if insert operator is present
@@ -1378,6 +1407,7 @@ class open_addressing_ref_impl {
13781407
if (group.any(empty)) { return; }
13791408

13801409
++probing_iter;
1410+
if (*probing_iter == init_idx) { return; }
13811411
}
13821412
}
13831413

@@ -1414,8 +1444,9 @@ class open_addressing_ref_impl {
14141444
CallbackOp&& callback_op,
14151445
SyncOp&& sync_op) const noexcept
14161446
{
1417-
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
1418-
bool empty = false;
1447+
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
1448+
auto const init_idx = *probing_iter;
1449+
bool empty = false;
14191450

14201451
while (true) {
14211452
// TODO atomic_ref::load if insert operator is present
@@ -1441,6 +1472,7 @@ class open_addressing_ref_impl {
14411472
if (group.any(empty)) { return; }
14421473

14431474
++probing_iter;
1475+
if (*probing_iter == init_idx) { return; }
14441476
}
14451477
}
14461478

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ class operator_impl<
493493
auto& probing_scheme = ref_.impl_.probing_scheme();
494494
auto storage_ref = ref_.impl_.storage_ref();
495495
auto probing_iter = probing_scheme(key, storage_ref.bucket_extent());
496+
auto const init_idx = *probing_iter;
496497

497498
while (true) {
498499
auto const bucket_slots = storage_ref[*probing_iter];
@@ -514,6 +515,7 @@ class operator_impl<
514515
}
515516
}
516517
++probing_iter;
518+
if (*probing_iter == init_idx) { return; }
517519
}
518520
}
519521

@@ -539,6 +541,7 @@ class operator_impl<
539541
auto& probing_scheme = ref_.impl_.probing_scheme();
540542
auto storage_ref = ref_.impl_.storage_ref();
541543
auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent());
544+
auto const init_idx = *probing_iter;
542545

543546
while (true) {
544547
auto const bucket_slots = storage_ref[*probing_iter];
@@ -578,6 +581,7 @@ class operator_impl<
578581
if (group.shfl(status, src_lane)) { return; }
579582
} else {
580583
++probing_iter;
584+
if (*probing_iter == init_idx) { return; }
581585
}
582586
}
583587
}
@@ -855,6 +859,7 @@ class operator_impl<
855859
auto& probing_scheme = ref_.impl_.probing_scheme();
856860
auto storage_ref = ref_.impl_.storage_ref();
857861
auto probing_iter = probing_scheme(key, storage_ref.bucket_extent());
862+
auto const init_idx = *probing_iter;
858863
auto const empty_value = ref_.empty_value_sentinel();
859864

860865
// wait for payload only when init != sentinel and insert strategy is not `packed_cas`
@@ -894,6 +899,7 @@ class operator_impl<
894899
}
895900
}
896901
++probing_iter;
902+
if (*probing_iter == init_idx) { return false; }
897903
}
898904
}
899905

@@ -929,6 +935,7 @@ class operator_impl<
929935
auto& probing_scheme = ref_.impl_.probing_scheme();
930936
auto storage_ref = ref_.impl_.storage_ref();
931937
auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent());
938+
auto const init_idx = *probing_iter;
932939
auto const empty_value = ref_.empty_value_sentinel();
933940

934941
// wait for payload only when init != sentinel and insert strategy is not `packed_cas`
@@ -987,6 +994,7 @@ class operator_impl<
987994
}
988995
} else {
989996
++probing_iter;
997+
if (*probing_iter == init_idx) { return false; }
990998
}
991999
}
9921000
}

0 commit comments

Comments
 (0)