Skip to content

Commit 9ef3535

Browse files
authored
Add static_multimap_ref::for_each (#599)
1 parent 4454de4 commit 9ef3535

File tree

3 files changed

+285
-1
lines changed

3 files changed

+285
-1
lines changed

include/cuco/detail/static_multimap/static_multimap_ref.inl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <cuda/atomic>
2323
#include <cuda/std/functional>
24+
#include <cuda/std/utility>
2425

2526
#include <cooperative_groups.h>
2627

@@ -487,6 +488,115 @@ class operator_impl<
487488
}
488489
};
489490

491+
template <typename Key,
492+
typename T,
493+
cuda::thread_scope Scope,
494+
typename KeyEqual,
495+
typename ProbingScheme,
496+
typename StorageRef,
497+
typename... Operators>
498+
class operator_impl<
499+
op::for_each_tag,
500+
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
501+
using base_type = static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
502+
using ref_type =
503+
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
504+
505+
static constexpr auto cg_size = base_type::cg_size;
506+
507+
public:
508+
/**
509+
* @brief Executes a callback on every element in the container with key equivalent to the probe
510+
* key.
511+
*
512+
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
513+
* `key` to the callback.
514+
*
515+
* @tparam ProbeKey Probe key type
516+
* @tparam CallbackOp Unary callback functor or device lambda
517+
*
518+
* @param key The key to search for
519+
* @param callback_op Function to call on every element found
520+
*/
521+
template <class ProbeKey, class CallbackOp>
522+
__device__ void for_each(ProbeKey const& key, CallbackOp&& callback_op) const noexcept
523+
{
524+
// CRTP: cast `this` to the actual ref type
525+
auto const& ref_ = static_cast<ref_type const&>(*this);
526+
ref_.impl_.for_each(key, cuda::std::forward<CallbackOp>(callback_op));
527+
}
528+
529+
/**
530+
* @brief Executes a callback on every element in the container with key equivalent to the probe
531+
* key.
532+
*
533+
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
534+
* `key` to the callback.
535+
*
536+
* @note This function uses cooperative group semantics, meaning that any thread may call the
537+
* callback if it finds a matching element. If multiple elements are found within the same group,
538+
* each thread with a match will call the callback with its associated element.
539+
*
540+
* @note Synchronizing `group` within `callback_op` is undefined behavior.
541+
*
542+
* @tparam ProbeKey Probe key type
543+
* @tparam CallbackOp Unary callback functor or device lambda
544+
*
545+
* @param group The Cooperative Group used to perform this operation
546+
* @param key The key to search for
547+
* @param callback_op Function to call on every element found
548+
*/
549+
template <class ProbeKey, class CallbackOp>
550+
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
551+
ProbeKey const& key,
552+
CallbackOp&& callback_op) const noexcept
553+
{
554+
// CRTP: cast `this` to the actual ref type
555+
auto const& ref_ = static_cast<ref_type const&>(*this);
556+
ref_.impl_.for_each(group, key, cuda::std::forward<CallbackOp>(callback_op));
557+
}
558+
559+
/**
560+
* @brief Executes a callback on every element in the container with key equivalent to the probe
561+
* key and can additionally perform work that requires synchronizing the Cooperative Group
562+
* performing this operation.
563+
*
564+
* @note Passes an un-incrementable input iterator to the element whose key is equivalent to
565+
* `key` to the callback.
566+
*
567+
* @note This function uses cooperative group semantics, meaning that any thread may call the
568+
* callback if it finds a matching element. If multiple elements are found within the same group,
569+
* each thread with a match will call the callback with its associated element.
570+
*
571+
* @note Synchronizing `group` within `callback_op` is undefined behavior.
572+
*
573+
* @note The `sync_op` function can be used to perform work that requires synchronizing threads in
574+
* `group` inbetween probing steps, where the number of probing steps performed between
575+
* synchronization points is capped by `window_size * cg_size`. The functor will be called right
576+
* after the current probing window has been traversed.
577+
*
578+
* @tparam ProbeKey Probe key type
579+
* @tparam CallbackOp Unary callback functor or device lambda
580+
* @tparam SyncOp Functor or device lambda which accepts the current `group` object
581+
*
582+
* @param group The Cooperative Group used to perform this operation
583+
* @param key The key to search for
584+
* @param callback_op Function to call on every element found
585+
* @param sync_op Function that is allowed to synchronize `group` inbetween probing windows
586+
*/
587+
template <class ProbeKey, class CallbackOp, class SyncOp>
588+
__device__ void for_each(cooperative_groups::thread_block_tile<cg_size> const& group,
589+
ProbeKey const& key,
590+
CallbackOp&& callback_op,
591+
SyncOp&& sync_op) const noexcept
592+
{
593+
// CRTP: cast `this` to the actual ref type
594+
auto const& ref_ = static_cast<ref_type const&>(*this);
595+
ref_.impl_.for_each(
596+
group, key, cuda::std::forward<CallbackOp>(callback_op), cuda::std::forward<SyncOp>(sync_op));
597+
}
598+
};
599+
490600
template <typename Key,
491601
typename T,
492602
cuda::thread_scope Scope,

tests/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ ConfigureTest(STATIC_MULTIMAP_TEST
117117
static_multimap/insert_if_test.cu
118118
static_multimap/multiplicity_test.cu
119119
static_multimap/non_match_test.cu
120-
static_multimap/pair_function_test.cu)
120+
static_multimap/pair_function_test.cu
121+
static_multimap/for_each_test.cu)
121122

122123
###################################################################################################
123124
# - dynamic_bitset tests --------------------------------------------------------------------------
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#include <test_utils.hpp>
18+
19+
#include <cuco/detail/utility/cuda.hpp>
20+
#include <cuco/static_multimap.cuh>
21+
22+
#include <cuda/atomic>
23+
#include <cuda/functional>
24+
#include <thrust/iterator/counting_iterator.h>
25+
#include <thrust/iterator/transform_iterator.h>
26+
27+
#include <cooperative_groups.h>
28+
#include <cooperative_groups/reduce.h>
29+
30+
#include <catch2/catch_template_test_macros.hpp>
31+
32+
#include <cstddef>
33+
34+
template <class Ref, class InputIt, class AtomicErrorCounter>
35+
CUCO_KERNEL void for_each_check_scalar(Ref ref,
36+
InputIt first,
37+
std::size_t n,
38+
std::size_t multiplicity,
39+
AtomicErrorCounter* error_counter)
40+
{
41+
static_assert(Ref::cg_size == 1, "Scalar test must have cg_size==1");
42+
auto const loop_stride = cuco::detail::grid_stride();
43+
auto idx = cuco::detail::global_thread_id();
44+
45+
while (idx < n) {
46+
auto const& key = *(first + idx);
47+
std::size_t matches = 0;
48+
ref.for_each(key, [&] __device__(auto const slot) {
49+
auto const [slot_key, slot_value] = slot;
50+
if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) { matches++; }
51+
});
52+
if (matches != multiplicity) { error_counter->fetch_add(1, cuda::memory_order_relaxed); }
53+
idx += loop_stride;
54+
}
55+
}
56+
57+
template <bool Synced, class Ref, class InputIt, class AtomicErrorCounter>
58+
CUCO_KERNEL void for_each_check_cooperative(Ref ref,
59+
InputIt first,
60+
std::size_t n,
61+
std::size_t multiplicity,
62+
AtomicErrorCounter* error_counter)
63+
{
64+
auto const loop_stride = cuco::detail::grid_stride() / Ref::cg_size;
65+
auto idx = cuco::detail::global_thread_id() / Ref::cg_size;
66+
;
67+
68+
while (idx < n) {
69+
auto const tile =
70+
cooperative_groups::tiled_partition<Ref::cg_size>(cooperative_groups::this_thread_block());
71+
auto const& key = *(first + idx);
72+
std::size_t thread_matches = 0;
73+
if constexpr (Synced) {
74+
ref.for_each(
75+
tile,
76+
key,
77+
[&] __device__(auto const slot) {
78+
auto const [slot_key, slot_value] = slot;
79+
if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) {
80+
thread_matches++;
81+
}
82+
},
83+
[] __device__(auto const& group) { group.sync(); });
84+
} else {
85+
ref.for_each(tile, key, [&] __device__(auto const slot) {
86+
auto const [slot_key, slot_value] = slot;
87+
if (ref.key_eq()(key, slot_key) and ref.key_eq()(slot_key, slot_value)) {
88+
thread_matches++;
89+
}
90+
});
91+
}
92+
auto const tile_matches =
93+
cooperative_groups::reduce(tile, thread_matches, cooperative_groups::plus<std::size_t>());
94+
if (tile_matches != multiplicity and tile.thread_rank() == 0) {
95+
error_counter->fetch_add(1, cuda::memory_order_relaxed);
96+
}
97+
idx += loop_stride;
98+
}
99+
}
100+
101+
TEMPLATE_TEST_CASE_SIG(
102+
"static_multimap for_each tests",
103+
"",
104+
((typename Key, cuco::test::probe_sequence Probe, int CGSize), Key, Probe, CGSize),
105+
(int32_t, cuco::test::probe_sequence::double_hashing, 1),
106+
(int32_t, cuco::test::probe_sequence::double_hashing, 2),
107+
(int64_t, cuco::test::probe_sequence::double_hashing, 1),
108+
(int64_t, cuco::test::probe_sequence::double_hashing, 2),
109+
(int32_t, cuco::test::probe_sequence::linear_probing, 1),
110+
(int32_t, cuco::test::probe_sequence::linear_probing, 2),
111+
(int64_t, cuco::test::probe_sequence::linear_probing, 1),
112+
(int64_t, cuco::test::probe_sequence::linear_probing, 2))
113+
{
114+
constexpr size_t num_unique_keys{400};
115+
constexpr size_t key_multiplicity{5};
116+
constexpr size_t num_keys{num_unique_keys * key_multiplicity};
117+
118+
using probe = std::conditional_t<Probe == cuco::test::probe_sequence::linear_probing,
119+
cuco::linear_probing<CGSize, cuco::default_hash_function<Key>>,
120+
cuco::double_hashing<CGSize, cuco::default_hash_function<Key>>>;
121+
122+
auto set = cuco::experimental::static_multimap{num_keys,
123+
cuco::empty_key<Key>{-1},
124+
cuco::empty_value<Key>{-1},
125+
{},
126+
probe{},
127+
{},
128+
cuco::storage<2>{}};
129+
130+
auto unique_keys_begin = thrust::counting_iterator<Key>(0);
131+
auto gen_duplicate_keys = cuda::proclaim_return_type<Key>(
132+
[] __device__(auto const& k) { return static_cast<Key>(k % num_unique_keys); });
133+
auto keys_begin = thrust::make_transform_iterator(unique_keys_begin, gen_duplicate_keys);
134+
135+
auto const pairs_begin = thrust::make_transform_iterator(
136+
keys_begin, cuda::proclaim_return_type<cuco::pair<Key, Key>>([] __device__(auto i) {
137+
return cuco::pair<Key, Key>{i, i};
138+
}));
139+
140+
set.insert(pairs_begin, pairs_begin + num_keys);
141+
142+
using error_counter_type = cuda::atomic<std::size_t, cuda::thread_scope_system>;
143+
error_counter_type* error_counter;
144+
CUCO_CUDA_TRY(cudaMallocHost(&error_counter, sizeof(error_counter_type)));
145+
new (error_counter) error_counter_type{0};
146+
147+
auto const grid_size = cuco::detail::grid_size(num_unique_keys, CGSize);
148+
auto const block_size = cuco::detail::default_block_size();
149+
150+
// test scalar for_each
151+
if constexpr (CGSize == 1) {
152+
for_each_check_scalar<<<grid_size, block_size>>>(
153+
set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter);
154+
CUCO_CUDA_TRY(cudaDeviceSynchronize());
155+
REQUIRE(error_counter->load() == 0);
156+
error_counter->store(0);
157+
}
158+
159+
// test CG for_each
160+
for_each_check_cooperative<false><<<grid_size, block_size>>>(
161+
set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter);
162+
CUCO_CUDA_TRY(cudaDeviceSynchronize());
163+
REQUIRE(error_counter->load() == 0);
164+
error_counter->store(0);
165+
166+
// test synchronized CG for_each
167+
for_each_check_cooperative<true><<<grid_size, block_size>>>(
168+
set.ref(cuco::for_each), unique_keys_begin, num_unique_keys, key_multiplicity, error_counter);
169+
CUCO_CUDA_TRY(cudaDeviceSynchronize());
170+
REQUIRE(error_counter->load() == 0);
171+
172+
CUCO_CUDA_TRY(cudaFreeHost(error_counter));
173+
}

0 commit comments

Comments
 (0)