Skip to content

Commit 5b4a80e

Browse files
PointKernelsleeepyjackpre-commit-ci[bot]
authored
Add consistent for_each APIs for cuco hash tables (#632)
This PR adds host and device `for_each` APIs for all cuco hash tables. --------- Co-authored-by: Daniel Jünger <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b29b608 commit 5b4a80e

File tree

11 files changed

+685
-120
lines changed

11 files changed

+685
-120
lines changed

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@
2121
#include <cuco/operator.hpp>
2222

2323
#include <cuda/atomic>
24-
#include <cuda/std/functional>
2524
#include <cuda/std/type_traits>
26-
#include <thrust/tuple.h>
25+
#include <cuda/std/utility>
2726

2827
#include <cooperative_groups.h>
2928

@@ -1335,7 +1334,7 @@ class operator_impl<
13351334
{
13361335
// CRTP: cast `this` to the actual ref type
13371336
auto const& ref_ = static_cast<ref_type const&>(*this);
1338-
ref_.impl_.for_each(key, std::forward<CallbackOp>(callback_op));
1337+
ref_.impl_.for_each(key, cuda::std::forward<CallbackOp>(callback_op));
13391338
}
13401339

13411340
/**
@@ -1363,7 +1362,7 @@ class operator_impl<
13631362
{
13641363
// CRTP: cast `this` to the actual ref type
13651364
auto const& ref_ = static_cast<ref_type const&>(*this);
1366-
ref_.impl_.for_each(group, key, std::forward<CallbackOp>(callback_op));
1365+
ref_.impl_.for_each(group, key, cuda::std::forward<CallbackOp>(callback_op));
13671366
}
13681367
};
13691368

include/cuco/detail/static_multimap/static_multimap.inl

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,73 @@ void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator,
315315
impl_->find_async(first, last, output_begin, ref(op::find), stream);
316316
}
317317

318+
template <class Key,
319+
class T,
320+
class Extent,
321+
cuda::thread_scope Scope,
322+
class KeyEqual,
323+
class ProbingScheme,
324+
class Allocator,
325+
class Storage>
326+
template <typename CallbackOp>
327+
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
328+
CallbackOp&& callback_op, cuda::stream_ref stream) const
329+
{
330+
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
331+
stream.wait();
332+
}
333+
334+
template <class Key,
335+
class T,
336+
class Extent,
337+
cuda::thread_scope Scope,
338+
class KeyEqual,
339+
class ProbingScheme,
340+
class Allocator,
341+
class Storage>
342+
template <typename CallbackOp>
343+
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
344+
for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const
345+
{
346+
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
347+
}
348+
349+
template <class Key,
350+
class T,
351+
class Extent,
352+
cuda::thread_scope Scope,
353+
class KeyEqual,
354+
class ProbingScheme,
355+
class Allocator,
356+
class Storage>
357+
template <typename InputIt, typename CallbackOp>
358+
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
359+
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const
360+
{
361+
impl_->for_each_async(
362+
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
363+
stream.wait();
364+
}
365+
366+
template <class Key,
367+
class T,
368+
class Extent,
369+
cuda::thread_scope Scope,
370+
class KeyEqual,
371+
class ProbingScheme,
372+
class Allocator,
373+
class Storage>
374+
template <typename InputIt, typename CallbackOp>
375+
void static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
376+
for_each_async(InputIt first,
377+
InputIt last,
378+
CallbackOp&& callback_op,
379+
cuda::stream_ref stream) const noexcept
380+
{
381+
impl_->for_each_async(
382+
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
383+
}
384+
318385
template <class Key,
319386
class T,
320387
class Extent,

include/cuco/detail/static_multiset/static_multiset.inl

Lines changed: 106 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -284,17 +284,12 @@ template <class Key,
284284
class ProbingScheme,
285285
class Allocator,
286286
class Storage>
287-
template <class InputProbeIt, class OutputProbeIt, class OutputMatchIt>
288-
std::pair<OutputProbeIt, OutputMatchIt>
289-
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
290-
InputProbeIt first,
291-
InputProbeIt last,
292-
OutputProbeIt output_probe,
293-
OutputMatchIt output_match,
294-
cuda::stream_ref stream) const
287+
template <typename CallbackOp>
288+
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
289+
CallbackOp&& callback_op, cuda::stream_ref stream) const
295290
{
296-
return this->impl_->retrieve(
297-
first, last, output_probe, output_match, this->ref(op::retrieve), stream);
291+
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
292+
stream.wait();
298293
}
299294

300295
template <class Key,
@@ -304,24 +299,11 @@ template <class Key,
304299
class ProbingScheme,
305300
class Allocator,
306301
class Storage>
307-
template <class InputProbeIt,
308-
class ProbeEqual,
309-
class ProbeHash,
310-
class OutputProbeIt,
311-
class OutputMatchIt>
312-
std::pair<OutputProbeIt, OutputMatchIt>
313-
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
314-
InputProbeIt first,
315-
InputProbeIt last,
316-
ProbeEqual const& probe_equal,
317-
ProbeHash const& probe_hash,
318-
OutputProbeIt output_probe,
319-
OutputMatchIt output_match,
320-
cuda::stream_ref stream) const
302+
template <typename CallbackOp>
303+
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
304+
for_each_async(CallbackOp&& callback_op, cuda::stream_ref stream) const
321305
{
322-
auto const probe_ref =
323-
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
324-
return this->impl_->retrieve(first, last, output_probe, output_match, probe_ref, stream);
306+
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
325307
}
326308

327309
template <class Key,
@@ -331,24 +313,31 @@ template <class Key,
331313
class ProbingScheme,
332314
class Allocator,
333315
class Storage>
334-
template <class InputProbeIt,
335-
class ProbeEqual,
336-
class ProbeHash,
337-
class OutputProbeIt,
338-
class OutputMatchIt>
339-
std::pair<OutputProbeIt, OutputMatchIt>
340-
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve_outer(
341-
InputProbeIt first,
342-
InputProbeIt last,
343-
ProbeEqual const& probe_equal,
344-
ProbeHash const& probe_hash,
345-
OutputProbeIt output_probe,
346-
OutputMatchIt output_match,
347-
cuda::stream_ref stream) const
316+
template <typename InputIt, typename CallbackOp>
317+
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
318+
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const
348319
{
349-
auto const probe_ref =
350-
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
351-
return this->impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream);
320+
impl_->for_each_async(
321+
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
322+
stream.wait();
323+
}
324+
325+
template <class Key,
326+
class Extent,
327+
cuda::thread_scope Scope,
328+
class KeyEqual,
329+
class ProbingScheme,
330+
class Allocator,
331+
class Storage>
332+
template <typename InputIt, typename CallbackOp>
333+
void static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
334+
for_each_async(InputIt first,
335+
InputIt last,
336+
CallbackOp&& callback_op,
337+
cuda::stream_ref stream) const noexcept
338+
{
339+
impl_->for_each_async(
340+
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
352341
}
353342

354343
template <class Key,
@@ -412,6 +401,79 @@ static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
412401
stream);
413402
}
414403

404+
template <class Key,
405+
class Extent,
406+
cuda::thread_scope Scope,
407+
class KeyEqual,
408+
class ProbingScheme,
409+
class Allocator,
410+
class Storage>
411+
template <class InputProbeIt, class OutputProbeIt, class OutputMatchIt>
412+
std::pair<OutputProbeIt, OutputMatchIt>
413+
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
414+
InputProbeIt first,
415+
InputProbeIt last,
416+
OutputProbeIt output_probe,
417+
OutputMatchIt output_match,
418+
cuda::stream_ref stream) const
419+
{
420+
return impl_->retrieve(first, last, output_probe, output_match, this->ref(op::retrieve), stream);
421+
}
422+
423+
template <class Key,
424+
class Extent,
425+
cuda::thread_scope Scope,
426+
class KeyEqual,
427+
class ProbingScheme,
428+
class Allocator,
429+
class Storage>
430+
template <class InputProbeIt,
431+
class ProbeEqual,
432+
class ProbeHash,
433+
class OutputProbeIt,
434+
class OutputMatchIt>
435+
std::pair<OutputProbeIt, OutputMatchIt>
436+
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
437+
InputProbeIt first,
438+
InputProbeIt last,
439+
ProbeEqual const& probe_equal,
440+
ProbeHash const& probe_hash,
441+
OutputProbeIt output_probe,
442+
OutputMatchIt output_match,
443+
cuda::stream_ref stream) const
444+
{
445+
auto const probe_ref =
446+
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
447+
return impl_->retrieve(first, last, output_probe, output_match, probe_ref, stream);
448+
}
449+
450+
template <class Key,
451+
class Extent,
452+
cuda::thread_scope Scope,
453+
class KeyEqual,
454+
class ProbingScheme,
455+
class Allocator,
456+
class Storage>
457+
template <class InputProbeIt,
458+
class ProbeEqual,
459+
class ProbeHash,
460+
class OutputProbeIt,
461+
class OutputMatchIt>
462+
std::pair<OutputProbeIt, OutputMatchIt>
463+
static_multiset<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve_outer(
464+
InputProbeIt first,
465+
InputProbeIt last,
466+
ProbeEqual const& probe_equal,
467+
ProbeHash const& probe_hash,
468+
OutputProbeIt output_probe,
469+
OutputMatchIt output_match,
470+
cuda::stream_ref stream) const
471+
{
472+
auto const probe_ref =
473+
this->ref(op::retrieve).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash);
474+
return impl_->retrieve_outer(first, last, output_probe, output_match, probe_ref, stream);
475+
}
476+
415477
template <class Key,
416478
class Extent,
417479
cuda::thread_scope Scope,

include/cuco/detail/static_set/static_set.inl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,66 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
338338
impl_->find_async(first, last, output_begin, ref(op::find), stream);
339339
}
340340

341+
template <class Key,
342+
class Extent,
343+
cuda::thread_scope Scope,
344+
class KeyEqual,
345+
class ProbingScheme,
346+
class Allocator,
347+
class Storage>
348+
template <typename CallbackOp>
349+
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
350+
CallbackOp&& callback_op, cuda::stream_ref stream) const
351+
{
352+
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
353+
stream.wait();
354+
}
355+
356+
template <class Key,
357+
class Extent,
358+
cuda::thread_scope Scope,
359+
class KeyEqual,
360+
class ProbingScheme,
361+
class Allocator,
362+
class Storage>
363+
template <typename CallbackOp>
364+
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
365+
CallbackOp&& callback_op, cuda::stream_ref stream) const
366+
{
367+
impl_->for_each_async(std::forward<CallbackOp>(callback_op), stream);
368+
}
369+
370+
template <class Key,
371+
class Extent,
372+
cuda::thread_scope Scope,
373+
class KeyEqual,
374+
class ProbingScheme,
375+
class Allocator,
376+
class Storage>
377+
template <typename InputIt, typename CallbackOp>
378+
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each(
379+
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const
380+
{
381+
impl_->for_each_async(
382+
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
383+
stream.wait();
384+
}
385+
386+
template <class Key,
387+
class Extent,
388+
cuda::thread_scope Scope,
389+
class KeyEqual,
390+
class ProbingScheme,
391+
class Allocator,
392+
class Storage>
393+
template <typename InputIt, typename CallbackOp>
394+
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::for_each_async(
395+
InputIt first, InputIt last, CallbackOp&& callback_op, cuda::stream_ref stream) const noexcept
396+
{
397+
impl_->for_each_async(
398+
first, last, std::forward<CallbackOp>(callback_op), ref(op::for_each), stream);
399+
}
400+
341401
template <class Key,
342402
class Extent,
343403
cuda::thread_scope Scope,

0 commit comments

Comments
 (0)