Skip to content

Commit edd129e

Browse files
committed
Improve OA retrieve implementations
1 parent 5db1066 commit edd129e

File tree

1 file changed

+97
-82
lines changed

1 file changed

+97
-82
lines changed

include/cuco/detail/open_addressing/open_addressing_ref_impl.cuh

Lines changed: 97 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
#include <cuda/atomic>
2727
#include <cuda/std/type_traits>
2828
#include <thrust/distance.h>
29+
#include <thrust/execution_policy.h>
30+
#include <thrust/logical.h>
31+
#include <thrust/reduce.h>
2932
#include <thrust/tuple.h>
3033
#if defined(CUCO_HAS_CUDA_BARRIER)
3134
#include <cuda/barrier>
@@ -1016,7 +1019,7 @@ class open_addressing_ref_impl {
10161019
InputProbeIt input_probe_end,
10171020
OutputProbeIt output_probe,
10181021
OutputMatchIt output_match,
1019-
AtomicCounter& atomic_counter) const
1022+
AtomicCounter* atomic_counter) const
10201023
{
10211024
auto constexpr is_outer = false;
10221025
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include
@@ -1065,7 +1068,7 @@ class open_addressing_ref_impl {
10651068
InputProbeIt input_probe_end,
10661069
OutputProbeIt output_probe,
10671070
OutputMatchIt output_match,
1068-
AtomicCounter& atomic_counter) const
1071+
AtomicCounter* atomic_counter) const
10691072
{
10701073
auto constexpr is_outer = true;
10711074
auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); // TODO include
@@ -1116,7 +1119,7 @@ class open_addressing_ref_impl {
11161119
cuco::detail::index_type n,
11171120
OutputProbeIt output_probe,
11181121
OutputMatchIt output_match,
1119-
AtomicCounter& atomic_counter) const
1122+
AtomicCounter* atomic_counter) const
11201123
{
11211124
namespace cg = cooperative_groups;
11221125

@@ -1143,26 +1146,24 @@ class open_addressing_ref_impl {
11431146
auto const stride = probing_tile.meta_group_size();
11441147
auto idx = probing_tile.meta_group_rank();
11451148

1146-
// TODO align to 16B?
11471149
__shared__ cuco::pair<probe_type, value_type> buffers[num_flushing_tiles][buffer_size];
1148-
size_type num_matches = 0;
1150+
__shared__ int32_t counters[num_flushing_tiles];
1151+
1152+
if (flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; }
1153+
flushing_tile.sync();
11491154

11501155
auto flush_buffers = [&](auto const& tile) {
11511156
size_type offset = 0;
1152-
/*
1153-
if (tile.thread_rank() == 0) {
1154-
offset = atomic_counter.fetch_add(num_matches, cuda::std::memory_order_relaxed);
1155-
}
1156-
*/
1157+
auto const count = counters[flushing_tile_id];
1158+
auto const rank = tile.thread_rank();
1159+
if (rank == 0) { offset = atomic_counter->fetch_add(count, cuda::memory_order_relaxed); }
11571160
offset = tile.shfl(offset, 0);
11581161

1159-
/*
11601162
// flush_buffers
1161-
for (size_type i = rank; i < num_matches; i += tile.size()) {
1163+
for (auto i = rank; i < count; i += tile.size()) {
11621164
*(output_probe + offset + i) = buffers[flushing_tile_id][i].first;
11631165
*(output_match + offset + i) = buffers[flushing_tile_id][i].second;
11641166
}
1165-
*/
11661167
};
11671168

11681169
while (flushing_tile.any(idx < n)) {
@@ -1176,102 +1177,116 @@ class open_addressing_ref_impl {
11761177
auto const& probe_key = *(input_probe + idx);
11771178
auto probing_iter =
11781179
this->probing_scheme_(probing_tile, probe_key, this->storage_ref_.bucket_extent());
1179-
bool running = true;
1180-
bool match_found = false;
1181-
[[maybe_unused]] bool found_any_match = false; // only needed if `IsOuter == true`
1182-
1183-
while (true) {
1184-
// TODO atomic_ref::load if insert operator is present
1185-
auto const bucket_slots = this->storage_ref_[*probing_iter];
1186-
1187-
for (int32_t i = 0; i < bucket_size; ++i) {
1188-
if (running) {
1189-
// inspect slot content
1190-
switch (this->predicate_.operator()<is_insert::NO>(
1191-
probe_key, this->extract_key(bucket_slots[i]))) {
1192-
case detail::equal_result::EMPTY: {
1193-
running = false;
1194-
break;
1195-
}
1196-
case detail::equal_result::EQUAL: {
1197-
if constexpr (!AllowsDuplicates) { running = false; }
1198-
match_found = true;
1199-
break;
1200-
}
1201-
default: {
1202-
break;
1180+
1181+
bool running = true;
1182+
[[maybe_unused]] bool found_match = false;
1183+
1184+
bool equals[buffer_size];
1185+
uint32_t exists[buffer_size];
1186+
1187+
while (active_flushing_tile.any(running)) {
1188+
if (running) {
1189+
// TODO atomic_ref::load if insert operator is present
1190+
auto const bucket_slots = this->storage_ref_[*probing_iter];
1191+
1192+
#pragma unroll buffer_size
1193+
for (int32_t i = 0; i < bucket_size; ++i) {
1194+
equals[i] = false;
1195+
if (running) {
1196+
// inspect slot content
1197+
switch (this->predicate_.operator()<is_insert::NO>(
1198+
probe_key, this->extract_key(bucket_slots[i]))) {
1199+
case detail::equal_result::EMPTY: {
1200+
running = false;
1201+
break;
1202+
}
1203+
case detail::equal_result::EQUAL: {
1204+
if constexpr (!AllowsDuplicates) { running = false; }
1205+
equals[i] = true;
1206+
break;
1207+
}
1208+
default: {
1209+
break;
1210+
}
12031211
}
12041212
}
12051213
}
12061214

1207-
if (active_flushing_tile.any(match_found)) {
1208-
auto const matching_tile = cg::binary_partition(active_flushing_tile, match_found);
1209-
// stage matches in shmem buffer
1210-
if (match_found) {
1211-
buffers[flushing_tile_id][num_matches + matching_tile.thread_rank()] = {
1212-
probe_key, bucket_slots[i]};
1215+
probing_tile.sync();
1216+
running = probing_tile.all(running);
1217+
#pragma unroll buffer_size
1218+
for (int32_t i = 0; i < bucket_size; ++i) {
1219+
exists[i] = probing_tile.ballot(equals[i]);
1220+
}
1221+
1222+
if (thrust::any_of(thrust::seq, exists, exists + bucket_size, thrust::identity{})) {
1223+
if constexpr (IsOuter) { found_match = true; }
1224+
1225+
int32_t num_matches[bucket_size];
1226+
1227+
for (int32_t i = 0; i < bucket_size; ++i) {
1228+
num_matches[i] = __popc(exists[i]);
12131229
}
12141230

1215-
// add number of new matches to the buffer counter
1216-
num_matches += (match_found) ? matching_tile.size()
1217-
: active_flushing_tile.size() - matching_tile.size();
1218-
}
1231+
auto const total_matches =
1232+
thrust::reduce(thrust::seq, num_matches, num_matches + bucket_size);
12191233

1220-
if constexpr (IsOuter) {
1221-
if (not found_any_match /*yet*/ and probing_tile.any(match_found) /*now*/) {
1222-
found_any_match = true;
1234+
int32_t output_idx;
1235+
if (probing_tile.thread_rank() == 0) {
1236+
auto ref =
1237+
cuda::atomic_ref<int32_t, cuda::thread_scope_block>{counters[flushing_tile_id]};
1238+
output_idx = ref.fetch_add(total_matches, cuda::memory_order_relaxed);
1239+
}
1240+
output_idx = probing_tile.shfl(output_idx, 0);
1241+
1242+
int32_t matche_offset = 0;
1243+
#pragma unroll buffer_size
1244+
for (int32_t i = 0; i < bucket_size; ++i) {
1245+
if (equals[i]) {
1246+
auto const lane_offset =
1247+
detail::count_least_significant_bits(exists[i], probing_tile.thread_rank());
1248+
buffers[flushing_tile_id][output_idx + matche_offset + lane_offset] = {
1249+
probe_key, bucket_slots[i]};
1250+
}
1251+
matche_offset += num_matches[i];
12231252
}
12241253
}
12251254

1226-
// reset flag for next iteration
1227-
match_found = false;
1228-
}
1229-
running = probing_tile.all(running);
1230-
1231-
// check if all probing tiles have finished their work
1232-
bool const finished = !active_flushing_tile.any(running);
1233-
1234-
if constexpr (IsOuter) {
1235-
if (finished) {
1236-
bool const writes_sentinel =
1237-
((probing_tile.thread_rank() == 0) and not found_any_match);
1238-
1239-
auto const sentinel_writers =
1240-
cg::binary_partition(active_flushing_tile, writes_sentinel);
1241-
if (writes_sentinel) {
1242-
auto const rank = sentinel_writers.thread_rank();
1243-
buffers[flushing_tile_id][num_matches + rank] = {probe_key,
1244-
this->empty_slot_sentinel()};
1255+
if constexpr (IsOuter) {
1256+
if (!running) {
1257+
if (!found_match and probing_tile.thread_rank() == 0) {
1258+
auto ref =
1259+
cuda::atomic_ref<int32_t, cuda::thread_scope_block>{counters[flushing_tile_id]};
1260+
auto const output_idx = ref.fetch_add(1, cuda::memory_order_relaxed);
1261+
buffers[flushing_tile_id][output_idx] = {probe_key, this->empty_slot_sentinel()};
1262+
}
12451263
}
1246-
// add number of new matches to the buffer counter
1247-
num_matches += (writes_sentinel)
1248-
? sentinel_writers.size()
1249-
: active_flushing_tile.size() - sentinel_writers.size();
12501264
}
1251-
}
1265+
} // if running
12521266

1267+
active_flushing_tile.sync();
12531268
// if the buffer has not enough empty slots for the next iteration
1254-
if (num_matches > (buffer_size - max_matches_per_step)) {
1269+
if (counters[flushing_tile_id] > (buffer_size - max_matches_per_step)) {
12551270
flush_buffers(active_flushing_tile);
1271+
active_flushing_tile.sync();
12561272

12571273
// reset buffer counter
1258-
num_matches = 0;
1274+
if (active_flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; }
1275+
active_flushing_tile.sync();
12591276
}
12601277

1261-
// the entire flushing tile has finished its work
1262-
if (finished) { break; }
1263-
12641278
// onto the next probing bucket
12651279
++probing_iter;
1266-
}
1267-
}
1280+
} // while running
1281+
} // if active_flag
12681282

12691283
// onto the next key
12701284
idx += stride;
12711285
}
12721286

1287+
flushing_tile.sync();
12731288
// entire flusing_tile has finished; flush remaining elements
1274-
if (num_matches > 0) { flush_buffers(flushing_tile); }
1289+
if (counters[flushing_tile_id] > 0) { flush_buffers(flushing_tile); }
12751290
}
12761291

12771292
/**

0 commit comments

Comments
 (0)