Skip to content

Commit 740dbae

Browse files
committed
Fix hash table full occupancy bug
1 parent b8429d4 commit 740dbae

File tree

2 files changed

+67
-27
lines changed

2 files changed

+67
-27
lines changed

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 59 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,10 @@ class open_addressing_ref_impl {
378378
{
379379
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
380380

381-
auto const val = this->heterogeneous_value(value);
382-
auto const key = this->extract_key(val);
383-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
381+
auto const val = this->heterogeneous_value(value);
382+
auto const key = this->extract_key(val);
383+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
384+
auto const init_idx = *probing_iter;
384385

385386
while (true) {
386387
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -411,6 +412,7 @@ class open_addressing_ref_impl {
411412
}
412413
}
413414
++probing_iter;
415+
if (*probing_iter == init_idx) { return false; }
414416
}
415417
}
416418

@@ -428,9 +430,10 @@ class open_addressing_ref_impl {
428430
__device__ bool insert(cooperative_groups::thread_block_tile<cg_size> const& group,
429431
Value const& value) noexcept
430432
{
431-
auto const val = this->heterogeneous_value(value);
432-
auto const key = this->extract_key(val);
433-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
433+
auto const val = this->heterogeneous_value(value);
434+
auto const key = this->extract_key(val);
435+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
436+
auto const init_idx = *probing_iter;
434437

435438
while (true) {
436439
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -483,6 +486,7 @@ class open_addressing_ref_impl {
483486
}
484487
} else {
485488
++probing_iter;
489+
if (*probing_iter == init_idx) { return false; }
486490
}
487491
}
488492
}
@@ -513,9 +517,10 @@ class open_addressing_ref_impl {
513517
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
514518
#endif
515519

516-
auto const val = this->heterogeneous_value(value);
517-
auto const key = this->extract_key(val);
518-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
520+
auto const val = this->heterogeneous_value(value);
521+
auto const key = this->extract_key(val);
522+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
523+
auto const init_idx = *probing_iter;
519524

520525
while (true) {
521526
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -554,6 +559,7 @@ class open_addressing_ref_impl {
554559
}
555560
}
556561
++probing_iter;
562+
if (*probing_iter == init_idx) { return {this->end(), false}; }
557563
};
558564
}
559565

@@ -584,9 +590,10 @@ class open_addressing_ref_impl {
584590
"insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs.");
585591
#endif
586592

587-
auto const val = this->heterogeneous_value(value);
588-
auto const key = this->extract_key(val);
589-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
593+
auto const val = this->heterogeneous_value(value);
594+
auto const key = this->extract_key(val);
595+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
596+
auto const init_idx = *probing_iter;
590597

591598
while (true) {
592599
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -653,6 +660,7 @@ class open_addressing_ref_impl {
653660
}
654661
} else {
655662
++probing_iter;
663+
if (*probing_iter == init_idx) { return {this->end(), false}; }
656664
}
657665
}
658666
}
@@ -671,7 +679,8 @@ class open_addressing_ref_impl {
671679
{
672680
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
673681

674-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
682+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
683+
auto const init_idx = *probing_iter;
675684

676685
while (true) {
677686
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -696,6 +705,7 @@ class open_addressing_ref_impl {
696705
}
697706
}
698707
++probing_iter;
708+
if (*probing_iter == init_idx) { return false; }
699709
}
700710
}
701711

@@ -713,7 +723,8 @@ class open_addressing_ref_impl {
713723
__device__ bool erase(cooperative_groups::thread_block_tile<cg_size> const& group,
714724
ProbeKey const& key) noexcept
715725
{
716-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
726+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
727+
auto const init_idx = *probing_iter;
717728

718729
while (true) {
719730
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -750,6 +761,7 @@ class open_addressing_ref_impl {
750761
if (group.any(state == detail::equal_result::EMPTY)) { return false; }
751762

752763
++probing_iter;
764+
if (*probing_iter == init_idx) { return false; }
753765
}
754766
}
755767

@@ -769,7 +781,8 @@ class open_addressing_ref_impl {
769781
[[nodiscard]] __device__ bool contains(ProbeKey const& key) const noexcept
770782
{
771783
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
772-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
784+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
785+
auto const init_idx = *probing_iter;
773786

774787
while (true) {
775788
// TODO atomic_ref::load if insert operator is present
@@ -783,6 +796,7 @@ class open_addressing_ref_impl {
783796
}
784797
}
785798
++probing_iter;
799+
if (*probing_iter == init_idx) { return false; }
786800
}
787801
}
788802

@@ -803,7 +817,8 @@ class open_addressing_ref_impl {
803817
[[nodiscard]] __device__ bool contains(
804818
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
805819
{
806-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
820+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
821+
auto const init_idx = *probing_iter;
807822

808823
while (true) {
809824
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -821,6 +836,7 @@ class open_addressing_ref_impl {
821836
if (group.any(state == detail::equal_result::EMPTY)) { return false; }
822837

823838
++probing_iter;
839+
if (*probing_iter == init_idx) { return false; }
824840
}
825841
}
826842

@@ -840,7 +856,8 @@ class open_addressing_ref_impl {
840856
[[nodiscard]] __device__ const_iterator find(ProbeKey const& key) const noexcept
841857
{
842858
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
843-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
859+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
860+
auto const init_idx = *probing_iter;
844861

845862
while (true) {
846863
// TODO atomic_ref::load if insert operator is present
@@ -859,6 +876,7 @@ class open_addressing_ref_impl {
859876
}
860877
}
861878
++probing_iter;
879+
if (*probing_iter == init_idx) { return this->end(); }
862880
}
863881
}
864882

@@ -879,7 +897,8 @@ class open_addressing_ref_impl {
879897
[[nodiscard]] __device__ const_iterator find(
880898
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
881899
{
882-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
900+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
901+
auto const init_idx = *probing_iter;
883902

884903
while (true) {
885904
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -908,6 +927,7 @@ class open_addressing_ref_impl {
908927
if (group.any(state == detail::equal_result::EMPTY)) { return this->end(); }
909928

910929
++probing_iter;
930+
if (*probing_iter == init_idx) { return this->end(); }
911931
}
912932
}
913933

@@ -926,8 +946,9 @@ class open_addressing_ref_impl {
926946
if constexpr (not allows_duplicates) {
927947
return static_cast<size_type>(this->contains(key));
928948
} else {
929-
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
930-
size_type count = 0;
949+
auto probing_iter = probing_scheme_(key, storage_ref_.bucket_extent());
950+
auto const init_idx = *probing_iter;
951+
size_type count = 0;
931952

932953
while (true) {
933954
// TODO atomic_ref::load if insert operator is present
@@ -942,6 +963,7 @@ class open_addressing_ref_impl {
942963
}
943964
}
944965
++probing_iter;
966+
if (*probing_iter == init_idx) { return count; }
945967
}
946968
}
947969
}
@@ -960,8 +982,9 @@ class open_addressing_ref_impl {
960982
[[nodiscard]] __device__ size_type count(
961983
cooperative_groups::thread_block_tile<cg_size> const& group, ProbeKey const& key) const noexcept
962984
{
963-
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
964-
size_type count = 0;
985+
auto probing_iter = probing_scheme_(group, key, storage_ref_.bucket_extent());
986+
auto const init_idx = *probing_iter;
987+
size_type count = 0;
965988

966989
while (true) {
967990
auto const bucket_slots = storage_ref_[*probing_iter];
@@ -978,6 +1001,7 @@ class open_addressing_ref_impl {
9781001

9791002
if (group.any(state == detail::equal_result::EMPTY)) { return count; }
9801003
++probing_iter;
1004+
if (*probing_iter == init_idx) { return count; }
9811005
}
9821006
}
9831007

@@ -1177,6 +1201,7 @@ class open_addressing_ref_impl {
11771201
auto const& probe_key = *(input_probe + idx);
11781202
auto probing_iter =
11791203
this->probing_scheme_(probing_tile, probe_key, this->storage_ref_.bucket_extent());
1204+
auto const init_idx = *probing_iter;
11801205

11811206
bool running = true;
11821207
[[maybe_unused]] bool found_match = false;
@@ -1277,6 +1302,7 @@ class open_addressing_ref_impl {
12771302

12781303
// onto the next probing bucket
12791304
++probing_iter;
1305+
if (*probing_iter == init_idx) { running = false; }
12801306
} // while running
12811307
} // if active_flag
12821308

@@ -1305,7 +1331,8 @@ class open_addressing_ref_impl {
13051331
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
13061332
{
13071333
static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme");
1308-
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent());
1334+
auto probing_iter = this->probing_scheme_(key, this->storage_ref_.bucket_extent());
1335+
auto const init_idx = *probing_iter;
13091336

13101337
while (true) {
13111338
// TODO atomic_ref::load if insert operator is present
@@ -1325,6 +1352,7 @@ class open_addressing_ref_impl {
13251352
}
13261353
}
13271354
++probing_iter;
1355+
if (*probing_iter == init_idx) { return; }
13281356
}
13291357
}
13301358

@@ -1352,8 +1380,9 @@ class open_addressing_ref_impl {
13521380
ProbeKey const& key,
13531381
CallbackOp&& callback_op) const noexcept
13541382
{
1355-
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
1356-
bool empty = false;
1383+
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
1384+
auto const init_idx = *probing_iter;
1385+
bool empty = false;
13571386

13581387
while (true) {
13591388
// TODO atomic_ref::load if insert operator is present
@@ -1378,6 +1407,7 @@ class open_addressing_ref_impl {
13781407
if (group.any(empty)) { return; }
13791408

13801409
++probing_iter;
1410+
if (*probing_iter == init_idx) { return; }
13811411
}
13821412
}
13831413

@@ -1414,8 +1444,9 @@ class open_addressing_ref_impl {
14141444
CallbackOp&& callback_op,
14151445
SyncOp&& sync_op) const noexcept
14161446
{
1417-
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
1418-
bool empty = false;
1447+
auto probing_iter = this->probing_scheme_(group, key, this->storage_ref_.bucket_extent());
1448+
auto const init_idx = *probing_iter;
1449+
bool empty = false;
14191450

14201451
while (true) {
14211452
// TODO atomic_ref::load if insert operator is present
@@ -1441,6 +1472,7 @@ class open_addressing_ref_impl {
14411472
if (group.any(empty)) { return; }
14421473

14431474
++probing_iter;
1475+
if (*probing_iter == init_idx) { return; }
14441476
}
14451477
}
14461478

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ class operator_impl<
493493
auto& probing_scheme = ref_.impl_.probing_scheme();
494494
auto storage_ref = ref_.impl_.storage_ref();
495495
auto probing_iter = probing_scheme(key, storage_ref.bucket_extent());
496+
auto const init_idx = *probing_iter;
496497

497498
while (true) {
498499
auto const bucket_slots = storage_ref[*probing_iter];
@@ -514,6 +515,7 @@ class operator_impl<
514515
}
515516
}
516517
++probing_iter;
518+
if (*probing_iter == init_idx) { return; }
517519
}
518520
}
519521

@@ -539,6 +541,7 @@ class operator_impl<
539541
auto& probing_scheme = ref_.impl_.probing_scheme();
540542
auto storage_ref = ref_.impl_.storage_ref();
541543
auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent());
544+
auto const init_idx = *probing_iter;
542545

543546
while (true) {
544547
auto const bucket_slots = storage_ref[*probing_iter];
@@ -578,6 +581,7 @@ class operator_impl<
578581
if (group.shfl(status, src_lane)) { return; }
579582
} else {
580583
++probing_iter;
584+
if (*probing_iter == init_idx) { return; }
581585
}
582586
}
583587
}
@@ -855,6 +859,7 @@ class operator_impl<
855859
auto& probing_scheme = ref_.impl_.probing_scheme();
856860
auto storage_ref = ref_.impl_.storage_ref();
857861
auto probing_iter = probing_scheme(key, storage_ref.bucket_extent());
862+
auto const init_idx = *probing_iter;
858863
auto const empty_value = ref_.empty_value_sentinel();
859864

860865
// wait for payload only when init != sentinel and insert strategy is not `packed_cas`
@@ -894,6 +899,7 @@ class operator_impl<
894899
}
895900
}
896901
++probing_iter;
902+
if (*probing_iter == init_idx) { return false; }
897903
}
898904
}
899905

@@ -929,6 +935,7 @@ class operator_impl<
929935
auto& probing_scheme = ref_.impl_.probing_scheme();
930936
auto storage_ref = ref_.impl_.storage_ref();
931937
auto probing_iter = probing_scheme(group, key, storage_ref.bucket_extent());
938+
auto const init_idx = *probing_iter;
932939
auto const empty_value = ref_.empty_value_sentinel();
933940

934941
// wait for payload only when init != sentinel and insert strategy is not `packed_cas`
@@ -987,6 +994,7 @@ class operator_impl<
987994
}
988995
} else {
989996
++probing_iter;
997+
if (*probing_iter == init_idx) { return false; }
990998
}
991999
}
9921000
}

0 commit comments

Comments
 (0)