Skip to content

Commit 759b710

Browse files
authored
Add host find APIs taking custom key equality and hasher (#645)
This PR introduces host find APIs for the hash set, enabling queries using a different key_eq and hash function. It can be used to improve cudf distinct join performance.
1 parent 58d79ee commit 759b710

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

include/cuco/detail/static_set/static_set.inl

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,29 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
337337
impl_->find_async(first, last, output_begin, ref(op::find), stream);
338338
}
339339

340+
template <class Key,
341+
class Extent,
342+
cuda::thread_scope Scope,
343+
class KeyEqual,
344+
class ProbingScheme,
345+
class Allocator,
346+
class Storage>
347+
template <typename InputIt, typename ProbeEqual, typename ProbeHash, typename OutputIt>
348+
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_async(
349+
InputIt first,
350+
InputIt last,
351+
ProbeEqual const& probe_equal,
352+
ProbeHash const& probe_hash,
353+
OutputIt output_begin,
354+
cuda::stream_ref stream) const
355+
{
356+
impl_->find_async(first,
357+
last,
358+
output_begin,
359+
ref(op::find).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash),
360+
stream);
361+
}
362+
340363
template <class Key,
341364
class Extent,
342365
cuda::thread_scope Scope,
@@ -376,6 +399,38 @@ void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>
376399
impl_->find_if_async(first, last, stencil, pred, output_begin, ref(op::find), stream);
377400
}
378401

402+
template <class Key,
403+
class Extent,
404+
cuda::thread_scope Scope,
405+
class KeyEqual,
406+
class ProbingScheme,
407+
class Allocator,
408+
class Storage>
409+
template <typename InputIt,
410+
typename StencilIt,
411+
typename Predicate,
412+
typename ProbeEqual,
413+
typename ProbeHash,
414+
typename OutputIt>
415+
void static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::find_if_async(
416+
InputIt first,
417+
InputIt last,
418+
StencilIt stencil,
419+
Predicate pred,
420+
ProbeEqual const& probe_equal,
421+
ProbeHash const& probe_hash,
422+
OutputIt output_begin,
423+
cuda::stream_ref stream) const
424+
{
425+
impl_->find_if_async(first,
426+
last,
427+
stencil,
428+
pred,
429+
output_begin,
430+
ref(op::find).rebind_key_eq(probe_equal).rebind_hash_function(probe_hash),
431+
stream);
432+
}
433+
379434
template <class Key,
380435
class Extent,
381436
cuda::thread_scope Scope,

include/cuco/static_set.cuh

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,34 @@ class static_set {
590590
OutputIt output_begin,
591591
cuda::stream_ref stream = {}) const;
592592

593+
/**
594+
* @brief For all keys in the range `[first, last)`, asynchronously finds an element with key
595+
* equivalent to the query key.
596+
*
597+
* @note If the key `*(first + i)` has a matched `element` in the set, copies `element` to
598+
* `(output_begin + i)`. Else, copies the empty key sentinel.
599+
*
600+
* @tparam InputIt Device accessible input iterator
601+
* @tparam ProbeEqual Binary callable equal type
602+
* @tparam ProbeHash Unary callable hasher type that can be constructed from
603+
* an integer value
604+
* @tparam OutputIt Device accessible output iterator assignable from the set's `key_type`
605+
*
606+
* @param first Beginning of the sequence of keys
607+
* @param last End of the sequence of keys
608+
* @param probe_equal The binary function to compare set keys and probe keys for equality
609+
* @param probe_hash The unary function to hash probe keys
610+
* @param output_begin Beginning of the sequence of elements retrieved for each key
611+
* @param stream Stream used for executing the kernels
612+
*/
613+
template <typename InputIt, typename ProbeEqual, typename ProbeHash, typename OutputIt>
614+
void find_async(InputIt first,
615+
InputIt last,
616+
ProbeEqual const& probe_equal,
617+
ProbeHash const& probe_hash,
618+
OutputIt output_begin,
619+
cuda::stream_ref stream = {}) const;
620+
593621
/**
594622
* @brief For all keys in the range `[first, last)`, finds a match with its key equivalent to the
595623
* query key.
@@ -654,6 +682,49 @@ class static_set {
654682
OutputIt output_begin,
655683
cuda::stream_ref stream = {}) const;
656684

685+
/**
686+
* @brief For all keys in the range `[first, last)`, asynchronously finds
687+
* a match with its key equivalent to the query key.
688+
*
689+
* @note If `pred( *(stencil + i) )` is true, stores the matched
690+
* element or the empty key sentinel to `(output_begin + i)`. If `pred( *(stencil + i) )`
691+
* is false, always stores the empty key sentinel to `(output_begin + i)`.
692+
*
693+
* @tparam InputIt Device accessible input iterator
694+
* @tparam StencilIt Device accessible random access iterator whose `value_type` is convertible to
695+
* Predicate's argument type
696+
* @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
697+
* argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
698+
* @tparam ProbeEqual Binary callable equal type
699+
* @tparam ProbeHash Unary callable hasher type that can be constructed from
700+
* an integer value
701+
* @tparam OutputIt Device accessible output iterator
702+
*
703+
* @param first Beginning of the sequence of keys
704+
* @param last End of the sequence of keys
705+
* @param stencil Beginning of the stencil sequence
706+
* @param pred Predicate to test on every element in the range `[stencil, stencil +
707+
* std::distance(first, last))`
708+
* @param probe_equal The binary function to compare set keys and probe keys for equality
709+
* @param probe_hash The unary function to hash probe keys
710+
* @param output_begin Beginning of the sequence of matches retrieved for each key
711+
* @param stream Stream used for executing the kernels
712+
*/
713+
template <typename InputIt,
714+
typename StencilIt,
715+
typename Predicate,
716+
typename ProbeEqual,
717+
typename ProbeHash,
718+
typename OutputIt>
719+
void find_if_async(InputIt first,
720+
InputIt last,
721+
StencilIt stencil,
722+
Predicate pred,
723+
ProbeEqual const& probe_equal,
724+
ProbeHash const& probe_hash,
725+
OutputIt output_begin,
726+
cuda::stream_ref stream = {}) const;
727+
657728
/**
658729
* @brief Applies the given function object `callback_op` to the copy of every filled slot in the
659730
* container

0 commit comments

Comments
 (0)