2626#include < cuda/atomic>
2727#include < cuda/std/type_traits>
2828#include < thrust/distance.h>
29+ #include < thrust/execution_policy.h>
30+ #include < thrust/logical.h>
31+ #include < thrust/reduce.h>
2932#include < thrust/tuple.h>
3033#if defined(CUCO_HAS_CUDA_BARRIER)
3134#include < cuda/barrier>
@@ -1016,7 +1019,7 @@ class open_addressing_ref_impl {
10161019 InputProbeIt input_probe_end,
10171020 OutputProbeIt output_probe,
10181021 OutputMatchIt output_match,
1019- AtomicCounter& atomic_counter) const
1022+ AtomicCounter* atomic_counter) const
10201023 {
10211024 auto constexpr is_outer = false ;
10221025 auto const n = cuco::detail::distance (input_probe_begin, input_probe_end); // TODO include
@@ -1065,7 +1068,7 @@ class open_addressing_ref_impl {
10651068 InputProbeIt input_probe_end,
10661069 OutputProbeIt output_probe,
10671070 OutputMatchIt output_match,
1068- AtomicCounter& atomic_counter) const
1071+ AtomicCounter* atomic_counter) const
10691072 {
10701073 auto constexpr is_outer = true ;
10711074 auto const n = cuco::detail::distance (input_probe_begin, input_probe_end); // TODO include
@@ -1116,7 +1119,7 @@ class open_addressing_ref_impl {
11161119 cuco::detail::index_type n,
11171120 OutputProbeIt output_probe,
11181121 OutputMatchIt output_match,
1119- AtomicCounter& atomic_counter) const
1122+ AtomicCounter* atomic_counter) const
11201123 {
11211124 namespace cg = cooperative_groups;
11221125
@@ -1143,26 +1146,24 @@ class open_addressing_ref_impl {
11431146 auto const stride = probing_tile.meta_group_size ();
11441147 auto idx = probing_tile.meta_group_rank ();
11451148
1146- // TODO align to 16B?
11471149 __shared__ cuco::pair<probe_type, value_type> buffers[num_flushing_tiles][buffer_size];
1148- size_type num_matches = 0 ;
1150+ __shared__ int32_t counters[num_flushing_tiles];
1151+
1152+ if (flushing_tile.thread_rank () == 0 ) { counters[flushing_tile_id] = 0 ; }
1153+ flushing_tile.sync ();
11491154
11501155 auto flush_buffers = [&](auto const & tile) {
11511156 size_type offset = 0 ;
1152- /*
1153- if (tile.thread_rank() == 0) {
1154- offset = atomic_counter.fetch_add(num_matches, cuda::std::memory_order_relaxed);
1155- }
1156- */
1157+ auto const count = counters[flushing_tile_id];
1158+ auto const rank = tile.thread_rank ();
1159+ if (rank == 0 ) { offset = atomic_counter->fetch_add (count, cuda::memory_order_relaxed); }
11571160 offset = tile.shfl (offset, 0 );
11581161
1159- /*
11601162 // flush_buffers
1161- for (size_type i = rank; i < num_matches ; i += tile.size()) {
1163+ for (auto i = rank; i < count ; i += tile.size ()) {
11621164 *(output_probe + offset + i) = buffers[flushing_tile_id][i].first ;
11631165 *(output_match + offset + i) = buffers[flushing_tile_id][i].second ;
11641166 }
1165- */
11661167 };
11671168
11681169 while (flushing_tile.any (idx < n)) {
@@ -1176,102 +1177,116 @@ class open_addressing_ref_impl {
11761177 auto const & probe_key = *(input_probe + idx);
11771178 auto probing_iter =
11781179 this ->probing_scheme_ (probing_tile, probe_key, this ->storage_ref_ .bucket_extent ());
1179- bool running = true ;
1180- bool match_found = false ;
1181- [[maybe_unused]] bool found_any_match = false ; // only needed if `IsOuter == true`
1182-
1183- while (true ) {
1184- // TODO atomic_ref::load if insert operator is present
1185- auto const bucket_slots = this ->storage_ref_ [*probing_iter];
1186-
1187- for (int32_t i = 0 ; i < bucket_size; ++i) {
1188- if (running) {
1189- // inspect slot content
1190- switch (this ->predicate_ .operator ()<is_insert::NO>(
1191- probe_key, this ->extract_key (bucket_slots[i]))) {
1192- case detail::equal_result::EMPTY: {
1193- running = false ;
1194- break ;
1195- }
1196- case detail::equal_result::EQUAL: {
1197- if constexpr (!AllowsDuplicates) { running = false ; }
1198- match_found = true ;
1199- break ;
1200- }
1201- default : {
1202- break ;
1180+
1181+ bool running = true ;
1182+ [[maybe_unused]] bool found_match = false ;
1183+
1184+ bool equals[buffer_size];
1185+ uint32_t exists[buffer_size];
1186+
1187+ while (active_flushing_tile.any (running)) {
1188+ if (running) {
1189+ // TODO atomic_ref::load if insert operator is present
1190+ auto const bucket_slots = this ->storage_ref_ [*probing_iter];
1191+
1192+ #pragma unroll buffer_size
1193+ for (int32_t i = 0 ; i < bucket_size; ++i) {
1194+ equals[i] = false ;
1195+ if (running) {
1196+ // inspect slot content
1197+ switch (this ->predicate_ .operator ()<is_insert::NO>(
1198+ probe_key, this ->extract_key (bucket_slots[i]))) {
1199+ case detail::equal_result::EMPTY: {
1200+ running = false ;
1201+ break ;
1202+ }
1203+ case detail::equal_result::EQUAL: {
1204+ if constexpr (!AllowsDuplicates) { running = false ; }
1205+ equals[i] = true ;
1206+ break ;
1207+ }
1208+ default : {
1209+ break ;
1210+ }
12031211 }
12041212 }
12051213 }
12061214
1207- if (active_flushing_tile.any (match_found)) {
1208- auto const matching_tile = cg::binary_partition (active_flushing_tile, match_found);
1209- // stage matches in shmem buffer
1210- if (match_found) {
1211- buffers[flushing_tile_id][num_matches + matching_tile.thread_rank ()] = {
1212- probe_key, bucket_slots[i]};
1215+ probing_tile.sync ();
1216+ running = probing_tile.all (running);
1217+ #pragma unroll buffer_size
1218+ for (int32_t i = 0 ; i < bucket_size; ++i) {
1219+ exists[i] = probing_tile.ballot (equals[i]);
1220+ }
1221+
1222+ if (thrust::any_of (thrust::seq, exists, exists + bucket_size, thrust::identity{})) {
1223+ if constexpr (IsOuter) { found_match = true ; }
1224+
1225+ int32_t num_matches[bucket_size];
1226+
1227+ for (int32_t i = 0 ; i < bucket_size; ++i) {
1228+ num_matches[i] = __popc (exists[i]);
12131229 }
12141230
1215- // add number of new matches to the buffer counter
1216- num_matches += (match_found) ? matching_tile.size ()
1217- : active_flushing_tile.size () - matching_tile.size ();
1218- }
1231+ auto const total_matches =
1232+ thrust::reduce (thrust::seq, num_matches, num_matches + bucket_size);
12191233
1220- if constexpr (IsOuter) {
1221- if (not found_any_match /* yet*/ and probing_tile.any (match_found) /* now*/ ) {
1222- found_any_match = true ;
1234+ int32_t output_idx;
1235+ if (probing_tile.thread_rank () == 0 ) {
1236+ auto ref =
1237+ cuda::atomic_ref<int32_t , cuda::thread_scope_block>{counters[flushing_tile_id]};
1238+ output_idx = ref.fetch_add (total_matches, cuda::memory_order_relaxed);
1239+ }
1240+ output_idx = probing_tile.shfl (output_idx, 0 );
1241+
1242+ int32_t matche_offset = 0 ;
1243+ #pragma unroll buffer_size
1244+ for (int32_t i = 0 ; i < bucket_size; ++i) {
1245+ if (equals[i]) {
1246+ auto const lane_offset =
1247+ detail::count_least_significant_bits (exists[i], probing_tile.thread_rank ());
1248+ buffers[flushing_tile_id][output_idx + matche_offset + lane_offset] = {
1249+ probe_key, bucket_slots[i]};
1250+ }
1251+ matche_offset += num_matches[i];
12231252 }
12241253 }
12251254
1226- // reset flag for next iteration
1227- match_found = false ;
1228- }
1229- running = probing_tile.all (running);
1230-
1231- // check if all probing tiles have finished their work
1232- bool const finished = !active_flushing_tile.any (running);
1233-
1234- if constexpr (IsOuter) {
1235- if (finished) {
1236- bool const writes_sentinel =
1237- ((probing_tile.thread_rank () == 0 ) and not found_any_match);
1238-
1239- auto const sentinel_writers =
1240- cg::binary_partition (active_flushing_tile, writes_sentinel);
1241- if (writes_sentinel) {
1242- auto const rank = sentinel_writers.thread_rank ();
1243- buffers[flushing_tile_id][num_matches + rank] = {probe_key,
1244- this ->empty_slot_sentinel ()};
1255+ if constexpr (IsOuter) {
1256+ if (!running) {
1257+ if (!found_match and probing_tile.thread_rank () == 0 ) {
1258+ auto ref =
1259+ cuda::atomic_ref<int32_t , cuda::thread_scope_block>{counters[flushing_tile_id]};
1260+ auto const output_idx = ref.fetch_add (1 , cuda::memory_order_relaxed);
1261+ buffers[flushing_tile_id][output_idx] = {probe_key, this ->empty_slot_sentinel ()};
1262+ }
12451263 }
1246- // add number of new matches to the buffer counter
1247- num_matches += (writes_sentinel)
1248- ? sentinel_writers.size ()
1249- : active_flushing_tile.size () - sentinel_writers.size ();
12501264 }
1251- }
1265+ } // if running
12521266
1267+ active_flushing_tile.sync ();
12531268 // if the buffer has not enough empty slots for the next iteration
1254- if (num_matches > (buffer_size - max_matches_per_step)) {
1269+ if (counters[flushing_tile_id] > (buffer_size - max_matches_per_step)) {
12551270 flush_buffers (active_flushing_tile);
1271+ active_flushing_tile.sync ();
12561272
12571273 // reset buffer counter
1258- num_matches = 0 ;
1274+ if (active_flushing_tile.thread_rank () == 0 ) { counters[flushing_tile_id] = 0 ; }
1275+ active_flushing_tile.sync ();
12591276 }
12601277
1261- // the entire flushing tile has finished its work
1262- if (finished) { break ; }
1263-
12641278 // onto the next probing bucket
12651279 ++probing_iter;
1266- }
1267- }
1280+ } // while running
1281+ } // if active_flag
12681282
12691283 // onto the next key
12701284 idx += stride;
12711285 }
12721286
1287+ flushing_tile.sync ();
12731288 // entire flushing_tile has finished; flush remaining elements
1274- if (num_matches > 0 ) { flush_buffers (flushing_tile); }
1289+ if (counters[flushing_tile_id] > 0 ) { flush_buffers (flushing_tile); }
12751290 }
12761291
12771292 /* *
0 commit comments