@@ -378,9 +378,10 @@ class open_addressing_ref_impl {
378378 {
379379 static_assert (cg_size == 1 , " Non-CG operation is incompatible with the current probing scheme" );
380380
381- auto const val = this ->heterogeneous_value (value);
382- auto const key = this ->extract_key (val);
383- auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
381+ auto const val = this ->heterogeneous_value (value);
382+ auto const key = this ->extract_key (val);
383+ auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
384+ auto const init_idx = *probing_iter;
384385
385386 while (true ) {
386387 auto const bucket_slots = storage_ref_[*probing_iter];
@@ -411,6 +412,7 @@ class open_addressing_ref_impl {
411412 }
412413 }
413414 ++probing_iter;
415+ if (*probing_iter == init_idx) { return false ; }
414416 }
415417 }
416418
@@ -428,9 +430,10 @@ class open_addressing_ref_impl {
428430 __device__ bool insert (cooperative_groups::thread_block_tile<cg_size> const & group,
429431 Value const & value) noexcept
430432 {
431- auto const val = this ->heterogeneous_value (value);
432- auto const key = this ->extract_key (val);
433- auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
433+ auto const val = this ->heterogeneous_value (value);
434+ auto const key = this ->extract_key (val);
435+ auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
436+ auto const init_idx = *probing_iter;
434437
435438 while (true ) {
436439 auto const bucket_slots = storage_ref_[*probing_iter];
@@ -483,6 +486,7 @@ class open_addressing_ref_impl {
483486 }
484487 } else {
485488 ++probing_iter;
489+ if (*probing_iter == init_idx) { return false ; }
486490 }
487491 }
488492 }
@@ -513,9 +517,10 @@ class open_addressing_ref_impl {
513517 " insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs." );
514518#endif
515519
516- auto const val = this ->heterogeneous_value (value);
517- auto const key = this ->extract_key (val);
518- auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
520+ auto const val = this ->heterogeneous_value (value);
521+ auto const key = this ->extract_key (val);
522+ auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
523+ auto const init_idx = *probing_iter;
519524
520525 while (true ) {
521526 auto const bucket_slots = storage_ref_[*probing_iter];
@@ -554,6 +559,7 @@ class open_addressing_ref_impl {
554559 }
555560 }
556561 ++probing_iter;
562+ if (*probing_iter == init_idx) { return {this ->end (), false }; }
557563 };
558564 }
559565
@@ -584,9 +590,10 @@ class open_addressing_ref_impl {
584590 " insert_and_find is not supported for pair types larger than 8 bytes on pre-Volta GPUs." );
585591#endif
586592
587- auto const val = this ->heterogeneous_value (value);
588- auto const key = this ->extract_key (val);
589- auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
593+ auto const val = this ->heterogeneous_value (value);
594+ auto const key = this ->extract_key (val);
595+ auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
596+ auto const init_idx = *probing_iter;
590597
591598 while (true ) {
592599 auto const bucket_slots = storage_ref_[*probing_iter];
@@ -653,6 +660,7 @@ class open_addressing_ref_impl {
653660 }
654661 } else {
655662 ++probing_iter;
663+ if (*probing_iter == init_idx) { return {this ->end (), false }; }
656664 }
657665 }
658666 }
@@ -671,7 +679,8 @@ class open_addressing_ref_impl {
671679 {
672680 static_assert (cg_size == 1 , " Non-CG operation is incompatible with the current probing scheme" );
673681
674- auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
682+ auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
683+ auto const init_idx = *probing_iter;
675684
676685 while (true ) {
677686 auto const bucket_slots = storage_ref_[*probing_iter];
@@ -696,6 +705,7 @@ class open_addressing_ref_impl {
696705 }
697706 }
698707 ++probing_iter;
708+ if (*probing_iter == init_idx) { return false ; }
699709 }
700710 }
701711
@@ -713,7 +723,8 @@ class open_addressing_ref_impl {
713723 __device__ bool erase (cooperative_groups::thread_block_tile<cg_size> const & group,
714724 ProbeKey const & key) noexcept
715725 {
716- auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
726+ auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
727+ auto const init_idx = *probing_iter;
717728
718729 while (true ) {
719730 auto const bucket_slots = storage_ref_[*probing_iter];
@@ -750,6 +761,7 @@ class open_addressing_ref_impl {
750761 if (group.any (state == detail::equal_result::EMPTY)) { return false ; }
751762
752763 ++probing_iter;
764+ if (*probing_iter == init_idx) { return false ; }
753765 }
754766 }
755767
@@ -769,7 +781,8 @@ class open_addressing_ref_impl {
769781 [[nodiscard]] __device__ bool contains (ProbeKey const & key) const noexcept
770782 {
771783 static_assert (cg_size == 1 , " Non-CG operation is incompatible with the current probing scheme" );
772- auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
784+ auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
785+ auto const init_idx = *probing_iter;
773786
774787 while (true ) {
775788 // TODO atomic_ref::load if insert operator is present
@@ -783,6 +796,7 @@ class open_addressing_ref_impl {
783796 }
784797 }
785798 ++probing_iter;
799+ if (*probing_iter == init_idx) { return false ; }
786800 }
787801 }
788802
@@ -803,7 +817,8 @@ class open_addressing_ref_impl {
803817 [[nodiscard]] __device__ bool contains (
804818 cooperative_groups::thread_block_tile<cg_size> const & group, ProbeKey const & key) const noexcept
805819 {
806- auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
820+ auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
821+ auto const init_idx = *probing_iter;
807822
808823 while (true ) {
809824 auto const bucket_slots = storage_ref_[*probing_iter];
@@ -821,6 +836,7 @@ class open_addressing_ref_impl {
821836 if (group.any (state == detail::equal_result::EMPTY)) { return false ; }
822837
823838 ++probing_iter;
839+ if (*probing_iter == init_idx) { return false ; }
824840 }
825841 }
826842
@@ -840,7 +856,8 @@ class open_addressing_ref_impl {
840856 [[nodiscard]] __device__ const_iterator find (ProbeKey const & key) const noexcept
841857 {
842858 static_assert (cg_size == 1 , " Non-CG operation is incompatible with the current probing scheme" );
843- auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
859+ auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
860+ auto const init_idx = *probing_iter;
844861
845862 while (true ) {
846863 // TODO atomic_ref::load if insert operator is present
@@ -859,6 +876,7 @@ class open_addressing_ref_impl {
859876 }
860877 }
861878 ++probing_iter;
879+ if (*probing_iter == init_idx) { return this ->end (); }
862880 }
863881 }
864882
@@ -879,7 +897,8 @@ class open_addressing_ref_impl {
879897 [[nodiscard]] __device__ const_iterator find (
880898 cooperative_groups::thread_block_tile<cg_size> const & group, ProbeKey const & key) const noexcept
881899 {
882- auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
900+ auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
901+ auto const init_idx = *probing_iter;
883902
884903 while (true ) {
885904 auto const bucket_slots = storage_ref_[*probing_iter];
@@ -908,6 +927,7 @@ class open_addressing_ref_impl {
908927 if (group.any (state == detail::equal_result::EMPTY)) { return this ->end (); }
909928
910929 ++probing_iter;
930+ if (*probing_iter == init_idx) { return this ->end (); }
911931 }
912932 }
913933
@@ -926,8 +946,9 @@ class open_addressing_ref_impl {
926946 if constexpr (not allows_duplicates) {
927947 return static_cast <size_type>(this ->contains (key));
928948 } else {
929- auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
930- size_type count = 0 ;
949+ auto probing_iter = probing_scheme_ (key, storage_ref_.bucket_extent ());
950+ auto const init_idx = *probing_iter;
951+ size_type count = 0 ;
931952
932953 while (true ) {
933954 // TODO atomic_ref::load if insert operator is present
@@ -942,6 +963,7 @@ class open_addressing_ref_impl {
942963 }
943964 }
944965 ++probing_iter;
966+ if (*probing_iter == init_idx) { return count; }
945967 }
946968 }
947969 }
@@ -960,8 +982,9 @@ class open_addressing_ref_impl {
960982 [[nodiscard]] __device__ size_type count (
961983 cooperative_groups::thread_block_tile<cg_size> const & group, ProbeKey const & key) const noexcept
962984 {
963- auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
964- size_type count = 0 ;
985+ auto probing_iter = probing_scheme_ (group, key, storage_ref_.bucket_extent ());
986+ auto const init_idx = *probing_iter;
987+ size_type count = 0 ;
965988
966989 while (true ) {
967990 auto const bucket_slots = storage_ref_[*probing_iter];
@@ -978,6 +1001,7 @@ class open_addressing_ref_impl {
9781001
9791002 if (group.any (state == detail::equal_result::EMPTY)) { return count; }
9801003 ++probing_iter;
1004+ if (*probing_iter == init_idx) { return count; }
9811005 }
9821006 }
9831007
@@ -1177,6 +1201,7 @@ class open_addressing_ref_impl {
11771201 auto const & probe_key = *(input_probe + idx);
11781202 auto probing_iter =
11791203 this ->probing_scheme_ (probing_tile, probe_key, this ->storage_ref_ .bucket_extent ());
1204+ auto const init_idx = *probing_iter;
11801205
11811206 bool running = true ;
11821207 [[maybe_unused]] bool found_match = false ;
@@ -1277,6 +1302,7 @@ class open_addressing_ref_impl {
12771302
12781303 // onto the next probing bucket
12791304 ++probing_iter;
1305+ if (*probing_iter == init_idx) { running = false ; }
12801306 } // while running
12811307 } // if active_flag
12821308
@@ -1305,7 +1331,8 @@ class open_addressing_ref_impl {
13051331 __device__ void for_each (ProbeKey const & key, CallbackOp&& callback_op) const noexcept
13061332 {
13071333 static_assert (cg_size == 1 , " Non-CG operation is incompatible with the current probing scheme" );
1308- auto probing_iter = this ->probing_scheme_ (key, this ->storage_ref_ .bucket_extent ());
1334+ auto probing_iter = this ->probing_scheme_ (key, this ->storage_ref_ .bucket_extent ());
1335+ auto const init_idx = *probing_iter;
13091336
13101337 while (true ) {
13111338 // TODO atomic_ref::load if insert operator is present
@@ -1325,6 +1352,7 @@ class open_addressing_ref_impl {
13251352 }
13261353 }
13271354 ++probing_iter;
1355+ if (*probing_iter == init_idx) { return ; }
13281356 }
13291357 }
13301358
@@ -1352,8 +1380,9 @@ class open_addressing_ref_impl {
13521380 ProbeKey const & key,
13531381 CallbackOp&& callback_op) const noexcept
13541382 {
1355- auto probing_iter = this ->probing_scheme_ (group, key, this ->storage_ref_ .bucket_extent ());
1356- bool empty = false ;
1383+ auto probing_iter = this ->probing_scheme_ (group, key, this ->storage_ref_ .bucket_extent ());
1384+ auto const init_idx = *probing_iter;
1385+ bool empty = false ;
13571386
13581387 while (true ) {
13591388 // TODO atomic_ref::load if insert operator is present
@@ -1378,6 +1407,7 @@ class open_addressing_ref_impl {
13781407 if (group.any (empty)) { return ; }
13791408
13801409 ++probing_iter;
1410+ if (*probing_iter == init_idx) { return ; }
13811411 }
13821412 }
13831413
@@ -1414,8 +1444,9 @@ class open_addressing_ref_impl {
14141444 CallbackOp&& callback_op,
14151445 SyncOp&& sync_op) const noexcept
14161446 {
1417- auto probing_iter = this ->probing_scheme_ (group, key, this ->storage_ref_ .bucket_extent ());
1418- bool empty = false ;
1447+ auto probing_iter = this ->probing_scheme_ (group, key, this ->storage_ref_ .bucket_extent ());
1448+ auto const init_idx = *probing_iter;
1449+ bool empty = false ;
14191450
14201451 while (true ) {
14211452 // TODO atomic_ref::load if insert operator is present
@@ -1441,6 +1472,7 @@ class open_addressing_ref_impl {
14411472 if (group.any (empty)) { return ; }
14421473
14431474 ++probing_iter;
1475+ if (*probing_iter == init_idx) { return ; }
14441476 }
14451477 }
14461478
0 commit comments