Skip to content

Commit 0b60143

Browse files
authored
Merge branch 'dev' into fix-full-occupancy
2 parents 740dbae + a4fb985 commit 0b60143

23 files changed

+596
-919
lines changed

include/cuco/detail/bloom_filter/bloom_filter_impl.cuh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ class bloom_filter_impl {
198198
auto const grid_size =
199199
cuco::detail::grid_size(num_keys, cg_size, cuco::detail::default_stride(), block_size);
200200

201-
detail::add_if_n<cg_size, block_size>
201+
detail::bloom_filter_ns::add_if_n<cg_size, block_size>
202202
<<<grid_size, block_size, 0, stream.get()>>>(first, num_keys, stencil, pred, *this);
203203
}
204204

@@ -303,8 +303,9 @@ class bloom_filter_impl {
303303
auto const grid_size =
304304
cuco::detail::grid_size(num_keys, cg_size, cuco::detail::default_stride(), block_size);
305305

306-
detail::contains_if_n<cg_size, block_size><<<grid_size, block_size, 0, stream.get()>>>(
307-
first, num_keys, stencil, pred, output_begin, *this);
306+
detail::bloom_filter_ns::contains_if_n<cg_size, block_size>
307+
<<<grid_size, block_size, 0, stream.get()>>>(
308+
first, num_keys, stencil, pred, output_begin, *this);
308309
}
309310

310311
[[nodiscard]] __host__ __device__ constexpr word_type* data() noexcept { return words_; }
@@ -365,4 +366,4 @@ class bloom_filter_impl {
365366
policy_type policy_;
366367
};
367368

368-
} // namespace cuco::detail
369+
} // namespace cuco::detail

include/cuco/detail/bloom_filter/kernels.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
#include <cstdint>
2323
#include <iterator>
2424

25-
namespace cuco::detail {
25+
namespace cuco::detail::bloom_filter_ns {
2626

2727
CUCO_SUPPRESS_KERNEL_WARNINGS
2828

@@ -89,4 +89,4 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void contains_if_n(InputIt first,
8989
}
9090
}
9191

92-
} // namespace cuco::detail
92+
} // namespace cuco::detail::bloom_filter_ns

include/cuco/detail/open_addressing/functors.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include <cuco/detail/bitwise_compare.cuh>
1919
#include <cuco/detail/pair/traits.hpp>
2020

21-
namespace cuco::open_addressing_ns::detail {
21+
namespace cuco::detail::open_addressing_ns {
2222

2323
/**
2424
* @brief Device functor returning the content of the slot indexed by `idx`
@@ -107,4 +107,4 @@ struct slot_is_filled {
107107
}
108108
};
109109

110-
} // namespace cuco::open_addressing_ns::detail
110+
} // namespace cuco::detail::open_addressing_ns

include/cuco/detail/open_addressing/kernels.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
#include <iterator>
2727

28-
namespace cuco::detail {
28+
namespace cuco::detail::open_addressing_ns {
2929
CUCO_SUPPRESS_KERNEL_WARNINGS
3030

3131
/**
@@ -729,4 +729,4 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void rehash(
729729
}
730730
}
731731

732-
} // namespace cuco::detail
732+
} // namespace cuco::detail::open_addressing_ns

include/cuco/detail/open_addressing/open_addressing_impl.cuh

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,7 @@ class open_addressing_impl {
342342

343343
auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);
344344

345-
detail::insert_if_n<cg_size, cuco::detail::default_block_size()>
345+
detail::open_addressing_ns::insert_if_n<cg_size, cuco::detail::default_block_size()>
346346
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
347347
first, num_keys, stencil, pred, counter.data(), container_ref);
348348

@@ -384,7 +384,7 @@ class open_addressing_impl {
384384

385385
auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);
386386

387-
detail::insert_if_n<cg_size, cuco::detail::default_block_size()>
387+
detail::open_addressing_ns::insert_if_n<cg_size, cuco::detail::default_block_size()>
388388
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
389389
first, num_keys, stencil, pred, container_ref);
390390
}
@@ -426,7 +426,7 @@ class open_addressing_impl {
426426

427427
auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);
428428

429-
detail::insert_and_find<cg_size, cuco::detail::default_block_size()>
429+
detail::open_addressing_ns::insert_and_find<cg_size, cuco::detail::default_block_size()>
430430
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
431431
first, num_keys, found_begin, inserted_begin, container_ref);
432432
}
@@ -466,7 +466,7 @@ class open_addressing_impl {
466466

467467
auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);
468468

469-
detail::erase<cg_size, cuco::detail::default_block_size()>
469+
detail::open_addressing_ns::erase<cg_size, cuco::detail::default_block_size()>
470470
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
471471
first, num_keys, container_ref);
472472
}
@@ -540,7 +540,7 @@ class open_addressing_impl {
540540

541541
auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);
542542

543-
detail::contains_if_n<cg_size, cuco::detail::default_block_size()>
543+
detail::open_addressing_ns::contains_if_n<cg_size, cuco::detail::default_block_size()>
544544
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
545545
first, num_keys, stencil, pred, output_begin, container_ref);
546546
}
@@ -615,7 +615,7 @@ class open_addressing_impl {
615615

616616
auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);
617617

618-
detail::find_if_n<cg_size, cuco::detail::default_block_size()>
618+
detail::open_addressing_ns::find_if_n<cg_size, cuco::detail::default_block_size()>
619619
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
620620
first, num_keys, stencil, pred, output_begin, container_ref);
621621
}
@@ -789,8 +789,8 @@ class open_addressing_impl {
789789
std::min(static_cast<cuco::detail::index_type>(this->capacity()) - offset, stride);
790790
auto const begin = thrust::make_transform_iterator(
791791
thrust::counting_iterator{static_cast<size_type>(offset)},
792-
open_addressing_ns::detail::get_slot<has_payload, storage_ref_type>(this->storage_ref()));
793-
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
792+
detail::open_addressing_ns::get_slot<has_payload, storage_ref_type>(this->storage_ref()));
793+
auto const is_filled = detail::open_addressing_ns::slot_is_filled<has_payload, key_type>{
794794
this->empty_key_sentinel(), this->erased_key_sentinel()};
795795

796796
std::size_t temp_storage_bytes = 0;
@@ -844,7 +844,7 @@ class open_addressing_impl {
844844
template <typename CallbackOp>
845845
void for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const
846846
{
847-
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
847+
auto const is_filled = detail::open_addressing_ns::slot_is_filled<has_payload, key_type>{
848848
this->empty_key_sentinel(), this->erased_key_sentinel()};
849849

850850
auto storage_ref = this->storage_ref();
@@ -886,7 +886,7 @@ class open_addressing_impl {
886886

887887
auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);
888888

889-
detail::for_each_n<cg_size, cuco::detail::default_block_size()>
889+
detail::open_addressing_ns::for_each_n<cg_size, cuco::detail::default_block_size()>
890890
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
891891
first, num_keys, std::forward<CallbackOp>(callback_op), container_ref);
892892
}
@@ -907,12 +907,12 @@ class open_addressing_impl {
907907
counter.reset(stream);
908908

909909
auto const grid_size = cuco::detail::grid_size(storage_.num_buckets());
910-
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
910+
auto const is_filled = detail::open_addressing_ns::slot_is_filled<has_payload, key_type>{
911911
this->empty_key_sentinel(), this->erased_key_sentinel()};
912912

913913
// TODO: custom kernel to be replaced by cub::DeviceReduce::Sum when cub version is bumped to
914914
// v2.1.0
915-
detail::size<cuco::detail::default_block_size()>
915+
detail::open_addressing_ns::size<cuco::detail::default_block_size()>
916916
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
917917
storage_.ref(), is_filled, counter.data());
918918

@@ -1014,10 +1014,10 @@ class open_addressing_impl {
10141014
auto constexpr block_size = cuco::detail::default_block_size();
10151015
auto constexpr stride = cuco::detail::default_stride();
10161016
auto const grid_size = cuco::detail::grid_size(num_buckets, 1, stride, block_size);
1017-
auto const is_filled = open_addressing_ns::detail::slot_is_filled<has_payload, key_type>{
1017+
auto const is_filled = detail::open_addressing_ns::slot_is_filled<has_payload, key_type>{
10181018
this->empty_key_sentinel(), this->erased_key_sentinel()};
10191019

1020-
detail::rehash<block_size><<<grid_size, block_size, 0, stream.get()>>>(
1020+
detail::open_addressing_ns::rehash<block_size><<<grid_size, block_size, 0, stream.get()>>>(
10211021
old_storage.ref(), container.ref(op::insert), is_filled);
10221022
}
10231023

@@ -1120,7 +1120,7 @@ class open_addressing_impl {
11201120

11211121
auto const grid_size = cuco::detail::grid_size(num_keys, cg_size);
11221122

1123-
detail::count<IsOuter, cg_size, cuco::detail::default_block_size()>
1123+
detail::open_addressing_ns::count<IsOuter, cg_size, cuco::detail::default_block_size()>
11241124
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
11251125
first, num_keys, counter.data(), container_ref);
11261126

@@ -1180,8 +1180,9 @@ class open_addressing_impl {
11801180
auto constexpr grid_stride = 1;
11811181
auto const grid_size = cuco::detail::grid_size(n, cg_size, grid_stride, block_size);
11821182

1183-
detail::retrieve<IsOuter, block_size><<<grid_size, block_size, 0, stream.get()>>>(
1184-
first, n, output_probe, output_match, counter.data(), container_ref);
1183+
detail::open_addressing_ns::retrieve<IsOuter, block_size>
1184+
<<<grid_size, block_size, 0, stream.get()>>>(
1185+
first, n, output_probe, output_match, counter.data(), container_ref);
11851186

11861187
auto const num_retrieved = counter.load_to_host(stream.get());
11871188

include/cuco/detail/static_map/helpers.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include <cuco/detail/static_map/kernels.cuh>
1919
#include <cuco/detail/utility/cuda.cuh>
2020

21-
namespace cuco::static_map_ns::detail {
21+
namespace cuco::detail::static_map_ns {
2222

2323
/**
2424
* @brief Dispatches to shared memory map kernel if `num_elements_per_thread > 2`, else
@@ -112,4 +112,4 @@ void dispatch_insert_or_apply(
112112
first, num, init, op, ref);
113113
}
114114
}
115-
} // namespace cuco::static_map_ns::detail
115+
} // namespace cuco::detail::static_map_ns

include/cuco/detail/static_map/kernels.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
#include <iterator>
2828

29-
namespace cuco::static_map_ns::detail {
29+
namespace cuco::detail::static_map_ns {
3030
CUCO_SUPPRESS_KERNEL_WARNINGS
3131

3232
// TODO use insert_or_assign internally
@@ -262,4 +262,4 @@ CUCO_KERNEL __launch_bounds__(BlockSize) void insert_or_apply_shmem(
262262
}
263263
}
264264
}
265-
} // namespace cuco::static_map_ns::detail
265+
} // namespace cuco::detail::static_map_ns

include/cuco/detail/static_map/static_map.inl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
284284

285285
auto const grid_size = cuco::detail::grid_size(num, cg_size);
286286

287-
static_map_ns::detail::insert_or_assign<cg_size, cuco::detail::default_block_size()>
287+
detail::static_map_ns::insert_or_assign<cg_size, cuco::detail::default_block_size()>
288288
<<<grid_size, cuco::detail::default_block_size(), 0, stream.get()>>>(
289289
first, num, ref(op::insert_or_assign));
290290
}
@@ -335,7 +335,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
335335
{
336336
auto constexpr has_init = false;
337337
auto const init = this->empty_value_sentinel(); // use empty_sentinel as unused init value
338-
static_map_ns::detail::dispatch_insert_or_apply<has_init, cg_size, Allocator>(
338+
detail::static_map_ns::dispatch_insert_or_apply<has_init, cg_size, Allocator>(
339339
first, last, init, op, ref(op::insert_or_apply), stream);
340340
}
341341

@@ -353,7 +353,7 @@ void static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
353353
InputIt first, InputIt last, Init init, Op op, cuda::stream_ref stream) noexcept
354354
{
355355
auto constexpr has_init = true;
356-
static_map_ns::detail::dispatch_insert_or_apply<has_init, cg_size, Allocator>(
356+
detail::static_map_ns::dispatch_insert_or_apply<has_init, cg_size, Allocator>(
357357
first, last, init, op, ref(op::insert_or_apply), stream);
358358
}
359359

@@ -612,6 +612,26 @@ static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
612612
return impl_->count(first, last, ref(op::count), stream);
613613
}
614614

615+
template <class Key,
616+
class T,
617+
class Extent,
618+
cuda::thread_scope Scope,
619+
class KeyEqual,
620+
class ProbingScheme,
621+
class Allocator,
622+
class Storage>
623+
template <typename InputIt, typename OutputProbeIt, typename OutputMatchIt>
624+
std::pair<OutputProbeIt, OutputMatchIt>
625+
static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
626+
InputIt first,
627+
InputIt last,
628+
OutputProbeIt output_probe,
629+
OutputMatchIt output_match,
630+
cuda::stream_ref stream) const
631+
{
632+
return impl_->retrieve(first, last, output_probe, output_match, this->ref(op::retrieve), stream);
633+
}
634+
615635
template <class Key,
616636
class T,
617637
class Extent,

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1428,5 +1428,73 @@ class operator_impl<
14281428
return ref_.impl_.count(group, key);
14291429
}
14301430
};
1431+
1432+
template <typename Key,
1433+
typename T,
1434+
cuda::thread_scope Scope,
1435+
typename KeyEqual,
1436+
typename ProbingScheme,
1437+
typename StorageRef,
1438+
typename... Operators>
1439+
class operator_impl<
1440+
op::retrieve_tag,
1441+
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
1442+
using base_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
1443+
using ref_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
1444+
using key_type = typename base_type::key_type;
1445+
using value_type = typename base_type::value_type;
1446+
using iterator = typename base_type::iterator;
1447+
using const_iterator = typename base_type::const_iterator;
1448+
1449+
static constexpr auto cg_size = base_type::cg_size;
1450+
static constexpr auto bucket_size = base_type::bucket_size;
1451+
1452+
public:
1453+
/**
1454+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
1455+
* input_probe_end)`.
1456+
*
1457+
* If key `k = *(input_probe_begin + i)` exists in the container, copies `k` to `output_probe` and
1458+
* the associated slot content to `output_match`, respectively. The output order is unspecified.
1459+
*
1460+
* Behavior is undefined if the number of retrieved slots exceeds the size of the output range.
1461+
* Use `count()` to determine the size of the output range.
1462+
*
1463+
* @tparam BlockSize Size of the thread block this operation is executed in
1464+
* @tparam InputProbeIt Device accessible input iterator whose `value_type` is
1465+
* convertible to the container's `key_type`
1466+
* @tparam OutputProbeIt Device accessible output iterator whose `value_type` is
1467+
* convertible to the container's `key_type`
1468+
* @tparam OutputMatchIt Device accessible output iterator whose `value_type` is
1469+
* convertible to the container's `value_type`
1470+
* @tparam AtomicCounter Atomic counter type that follows the same semantics as
1471+
* `cuda::atomic(_ref)`
1472+
*
1473+
* @param block Thread block this operation is executed in
1474+
* @param input_probe_begin Beginning of the input sequence of keys
1475+
* @param input_probe_end End of the input sequence of keys
1476+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
1477+
* `output_match`
1478+
* @param output_match Beginning of the sequence of matching elements
1479+
* @param atomic_counter Counter that is used to determine the next free position in the output
1480+
* sequences
1481+
*/
1482+
template <int32_t BlockSize,
1483+
class InputProbeIt,
1484+
class OutputProbeIt,
1485+
class OutputMatchIt,
1486+
class AtomicCounter>
1487+
__device__ void retrieve(cooperative_groups::thread_block const& block,
1488+
InputProbeIt input_probe_begin,
1489+
InputProbeIt input_probe_end,
1490+
OutputProbeIt output_probe,
1491+
OutputMatchIt output_match,
1492+
AtomicCounter* atomic_counter) const
1493+
{
1494+
auto const& ref_ = static_cast<ref_type const&>(*this);
1495+
ref_.impl_.retrieve<BlockSize>(
1496+
block, input_probe_begin, input_probe_end, output_probe, output_match, atomic_counter);
1497+
}
1498+
};
14311499
} // namespace detail
14321500
} // namespace cuco

include/cuco/detail/static_multimap/static_multimap.inl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,26 @@ static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
439439
return impl_->count(first, last, ref(op::count), stream);
440440
}
441441

442+
template <class Key,
443+
class T,
444+
class Extent,
445+
cuda::thread_scope Scope,
446+
class KeyEqual,
447+
class ProbingScheme,
448+
class Allocator,
449+
class Storage>
450+
template <typename InputIt, typename OutputProbeIt, typename OutputMatchIt>
451+
std::pair<OutputProbeIt, OutputMatchIt>
452+
static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
453+
InputIt first,
454+
InputIt last,
455+
OutputProbeIt output_probe,
456+
OutputMatchIt output_match,
457+
cuda::stream_ref stream) const
458+
{
459+
return impl_->retrieve(first, last, output_probe, output_match, this->ref(op::retrieve), stream);
460+
}
461+
442462
template <class Key,
443463
class T,
444464
class Extent,

0 commit comments

Comments
 (0)