Skip to content

Commit a4fb985

Browse files
PointKernelpre-commit-ci[bot]sleeepyjack
authored
Add retrieve for new map/multimap (#643)
This PR introduces `retrieve` APIs for map and multimap, while also removing the outdated and odd `pair_count/retrieve` tests. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Daniel Jünger <[email protected]>
1 parent 759b710 commit a4fb985

File tree

15 files changed

+433
-884
lines changed

15 files changed

+433
-884
lines changed

include/cuco/detail/static_map/static_map.inl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,26 @@ static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::
612612
return impl_->count(first, last, ref(op::count), stream);
613613
}
614614

615+
template <class Key,
616+
class T,
617+
class Extent,
618+
cuda::thread_scope Scope,
619+
class KeyEqual,
620+
class ProbingScheme,
621+
class Allocator,
622+
class Storage>
623+
template <typename InputIt, typename OutputProbeIt, typename OutputMatchIt>
624+
std::pair<OutputProbeIt, OutputMatchIt>
625+
static_map<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
626+
InputIt first,
627+
InputIt last,
628+
OutputProbeIt output_probe,
629+
OutputMatchIt output_match,
630+
cuda::stream_ref stream) const
631+
{
632+
return impl_->retrieve(first, last, output_probe, output_match, this->ref(op::retrieve), stream);
633+
}
634+
615635
template <class Key,
616636
class T,
617637
class Extent,

include/cuco/detail/static_map/static_map_ref.inl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,5 +1420,73 @@ class operator_impl<
14201420
return ref_.impl_.count(group, key);
14211421
}
14221422
};
1423+
1424+
template <typename Key,
1425+
typename T,
1426+
cuda::thread_scope Scope,
1427+
typename KeyEqual,
1428+
typename ProbingScheme,
1429+
typename StorageRef,
1430+
typename... Operators>
1431+
class operator_impl<
1432+
op::retrieve_tag,
1433+
static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
1434+
using base_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
1435+
using ref_type = static_map_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
1436+
using key_type = typename base_type::key_type;
1437+
using value_type = typename base_type::value_type;
1438+
using iterator = typename base_type::iterator;
1439+
using const_iterator = typename base_type::const_iterator;
1440+
1441+
static constexpr auto cg_size = base_type::cg_size;
1442+
static constexpr auto bucket_size = base_type::bucket_size;
1443+
1444+
public:
1445+
/**
1446+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
1447+
* input_probe_end)`.
1448+
*
1449+
* If key `k = *(first + i)` exists in the container, copies `k` to `output_probe` and associated
1450+
* slot content to `output_match`, respectively. The output order is unspecified.
1451+
*
1452+
* Behavior is undefined if the size of the output range exceeds the number of retrieved slots.
1453+
* Use `count()` to determine the size of the output range.
1454+
*
1455+
* @tparam BlockSize Size of the thread block this operation is executed in
1456+
* @tparam InputProbeIt Device accessible input iterator whose `value_type` is
1457+
* convertible to the container's `key_type`
1458+
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
1459+
* convertible to the container's `key_type`
1460+
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
1461+
* convertible to the container's `value_type`
1462+
* @tparam AtomicCounter Atomic counter type that follows the same semantics as
1463+
* `cuda::atomic(_ref)`
1464+
*
1465+
* @param block Thread block this operation is executed in
1466+
* @param input_probe_begin Beginning of the input sequence of keys
1467+
* @param input_probe_end End of the input sequence of keys
1468+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
1469+
* `output_match`
1470+
* @param output_match Beginning of the sequence of matching elements
1471+
* @param atomic_counter Counter that is used to determine the next free position in the output
1472+
* sequences
1473+
*/
1474+
template <int32_t BlockSize,
1475+
class InputProbeIt,
1476+
class OutputProbeIt,
1477+
class OutputMatchIt,
1478+
class AtomicCounter>
1479+
__device__ void retrieve(cooperative_groups::thread_block const& block,
1480+
InputProbeIt input_probe_begin,
1481+
InputProbeIt input_probe_end,
1482+
OutputProbeIt output_probe,
1483+
OutputMatchIt output_match,
1484+
AtomicCounter* atomic_counter) const
1485+
{
1486+
auto const& ref_ = static_cast<ref_type const&>(*this);
1487+
ref_.impl_.retrieve<BlockSize>(
1488+
block, input_probe_begin, input_probe_end, output_probe, output_match, atomic_counter);
1489+
}
1490+
};
14231491
} // namespace detail
14241492
} // namespace cuco

include/cuco/detail/static_multimap/static_multimap.inl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,26 @@ static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Stora
439439
return impl_->count(first, last, ref(op::count), stream);
440440
}
441441

442+
template <class Key,
443+
class T,
444+
class Extent,
445+
cuda::thread_scope Scope,
446+
class KeyEqual,
447+
class ProbingScheme,
448+
class Allocator,
449+
class Storage>
450+
template <typename InputIt, typename OutputProbeIt, typename OutputMatchIt>
451+
std::pair<OutputProbeIt, OutputMatchIt>
452+
static_multimap<Key, T, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
453+
InputIt first,
454+
InputIt last,
455+
OutputProbeIt output_probe,
456+
OutputMatchIt output_match,
457+
cuda::stream_ref stream) const
458+
{
459+
return impl_->retrieve(first, last, output_probe, output_match, this->ref(op::retrieve), stream);
460+
}
461+
442462
template <class Key,
443463
class T,
444464
class Extent,

include/cuco/detail/static_multimap/static_multimap_ref.inl

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,5 +750,73 @@ class operator_impl<
750750
}
751751
};
752752

753+
template <typename Key,
754+
typename T,
755+
cuda::thread_scope Scope,
756+
typename KeyEqual,
757+
typename ProbingScheme,
758+
typename StorageRef,
759+
typename... Operators>
760+
class operator_impl<
761+
op::retrieve_tag,
762+
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>> {
763+
using base_type = static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef>;
764+
using ref_type =
765+
static_multimap_ref<Key, T, Scope, KeyEqual, ProbingScheme, StorageRef, Operators...>;
766+
using key_type = typename base_type::key_type;
767+
using value_type = typename base_type::value_type;
768+
using iterator = typename base_type::iterator;
769+
using const_iterator = typename base_type::const_iterator;
770+
771+
static constexpr auto cg_size = base_type::cg_size;
772+
static constexpr auto bucket_size = base_type::bucket_size;
773+
774+
public:
775+
/**
776+
* @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin,
777+
* input_probe_end)`.
778+
*
779+
* If key `k = *(first + i)` exists in the container, copies `k` to `output_probe` and associated
780+
* slot content to `output_match`, respectively. The output order is unspecified.
781+
*
782+
* Behavior is undefined if the size of the output range exceeds the number of retrieved slots.
783+
* Use `count()` to determine the size of the output range.
784+
*
785+
* @tparam BlockSize Size of the thread block this operation is executed in
786+
* @tparam InputProbeIt Device accessible input iterator whose `value_type` is
787+
* convertible to the container's `key_type`
788+
* @tparam OutputProbeIt Device accessible input iterator whose `value_type` is
789+
* convertible to the container's `key_type`
790+
* @tparam OutputMatchIt Device accessible input iterator whose `value_type` is
791+
* convertible to the container's `value_type`
792+
* @tparam AtomicCounter Atomic counter type that follows the same semantics as
793+
* `cuda::atomic(_ref)`
794+
*
795+
* @param block Thread block this operation is executed in
796+
* @param input_probe_begin Beginning of the input sequence of keys
797+
* @param input_probe_end End of the input sequence of keys
798+
* @param output_probe Beginning of the sequence of keys corresponding to matching elements in
799+
* `output_match`
800+
* @param output_match Beginning of the sequence of matching elements
801+
* @param atomic_counter Counter that is used to determine the next free position in the output
802+
* sequences
803+
*/
804+
template <int32_t BlockSize,
805+
class InputProbeIt,
806+
class OutputProbeIt,
807+
class OutputMatchIt,
808+
class AtomicCounter>
809+
__device__ void retrieve(cooperative_groups::thread_block const& block,
810+
InputProbeIt input_probe_begin,
811+
InputProbeIt input_probe_end,
812+
OutputProbeIt output_probe,
813+
OutputMatchIt output_match,
814+
AtomicCounter* atomic_counter) const
815+
{
816+
auto const& ref_ = static_cast<ref_type const&>(*this);
817+
ref_.impl_.retrieve<BlockSize>(
818+
block, input_probe_begin, input_probe_end, output_probe, output_match, atomic_counter);
819+
}
820+
};
753821
} // namespace detail
754822
} // namespace cuco

include/cuco/detail/static_set/static_set.inl

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -525,25 +525,6 @@ static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::ret
525525
return impl_->retrieve(first, last, output_probe, output_match, this->ref(op::retrieve), stream);
526526
}
527527

528-
template <class Key,
529-
class Extent,
530-
cuda::thread_scope Scope,
531-
class KeyEqual,
532-
class ProbingScheme,
533-
class Allocator,
534-
class Storage>
535-
template <typename InputIt, typename OutputIt, typename ProbeEqual, typename ProbeHash>
536-
OutputIt static_set<Key, Extent, Scope, KeyEqual, ProbingScheme, Allocator, Storage>::retrieve(
537-
InputIt first,
538-
InputIt last,
539-
OutputIt output_begin,
540-
ProbeEqual const& probe_equal,
541-
ProbeHash const& probe_hash,
542-
cuda::stream_ref stream) const
543-
{
544-
CUCO_FAIL("Unsupported code path: retrieve with custom hash/equal");
545-
}
546-
547528
template <class Key,
548529
class Extent,
549530
cuda::thread_scope Scope,

include/cuco/static_map.cuh

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,40 @@ class static_map {
915915
template <typename InputIt>
916916
size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const;
917917

918+
/**
919+
* @brief Retrieves the matched key-value pair in the map corresponding to all probe keys in the
920+
* range
921+
* `[first, last)`
922+
*
923+
* If key `k = *(first + i)` has a match `m` in the map, copies a `cuco::pair{k, m}` to
924+
* unspecified locations in `[output_begin, output_end)`. Else, does nothing.
925+
*
926+
* @note This function synchronizes the given stream.
927+
* @note Behavior is undefined if the size of the output range exceeds
928+
* `std::distance(output_begin, output_end)`.
929+
* @note Behavior is undefined if the given key has multiple matches in the set.
930+
*
931+
* @tparam InputIt Device accessible input iterator
932+
* @tparam OutputProbeIt Device accessible output iterator whose `value_type` can be constructed
933+
* from `ProbeKey`
934+
* @tparam OutputMatchIt Device accessible output iterator whose `value_type` can be constructed
935+
* from map's `value_type`
936+
*
937+
* @param first Beginning of the sequence of probe keys
938+
* @param last End of the sequence of probe keys
939+
* @param output_probe Beginning of the sequence of the probe keys that have a match
940+
* @param output_match Beginning of the sequence of the matched key-value pairs
941+
* @param stream CUDA stream used for retrieve
942+
*
943+
* @return The iterator indicating the last valid pair in the output
944+
*/
945+
template <typename InputIt, typename OutputProbeIt, typename OutputMatchIt>
946+
std::pair<OutputProbeIt, OutputMatchIt> retrieve(InputIt first,
947+
InputIt last,
948+
OutputProbeIt output_probe,
949+
OutputMatchIt output_match,
950+
cuda::stream_ref stream = {}) const;
951+
918952
/**
919953
* @brief Retrieves all of the keys and their associated values contained in the map
920954
*

include/cuco/static_multimap.cuh

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,40 @@ class static_multimap {
666666
template <typename InputIt>
667667
size_type count(InputIt first, InputIt last, cuda::stream_ref stream = {}) const;
668668

669+
/**
670+
* @brief Retrieves the matched key-value pair in the multimap corresponding to all probe keys in
671+
* the range
672+
* `[first, last)`
673+
*
674+
* If key `k = *(first + i)` has a match `m` in the multimap, copies a `cuco::pair{k, m}` to
675+
* unspecified locations in `[output_begin, output_end)`. Else, does nothing.
676+
*
677+
* @note This function synchronizes the given stream.
678+
* @note Behavior is undefined if the size of the output range exceeds
679+
* `std::distance(output_begin, output_end)`.
680+
* @note Behavior is undefined if the given key has multiple matches in the set.
681+
*
682+
* @tparam InputIt Device accessible input iterator
683+
* @tparam OutputProbeIt Device accessible output iterator whose `value_type` can be constructed
684+
* from `ProbeKey`
685+
* @tparam OutputMatchIt Device accessible output iterator whose `value_type` can be constructed
686+
* from multimap's `value_type`
687+
*
688+
* @param first Beginning of the sequence of probe keys
689+
* @param last End of the sequence of probe keys
690+
* @param output_probe Beginning of the sequence of the probe keys that have a match
691+
* @param output_match Beginning of the sequence of the matched key-value pairs
692+
* @param stream CUDA stream used for retrieve
693+
*
694+
* @return The iterator indicating the last valid pair in the output
695+
*/
696+
template <typename InputIt, typename OutputProbeIt, typename OutputMatchIt>
697+
std::pair<OutputProbeIt, OutputMatchIt> retrieve(InputIt first,
698+
InputIt last,
699+
OutputProbeIt output_probe,
700+
OutputMatchIt output_match,
701+
cuda::stream_ref stream = {}) const;
702+
669703
/**
670704
* @brief Retrieves all of the keys and their associated values contained in the multimap
671705
*

include/cuco/static_set.cuh

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -842,44 +842,6 @@ class static_set {
842842
OutputIt2 output_match,
843843
cuda::stream_ref stream = {}) const;
844844

845-
/**
846-
* @brief Asynchronously retrieves the matched key in the set corresponding to all probe keys in
847-
* the range `[first, last)`
848-
*
849-
* If key `k = *(first + i)` has a match `m` in the set, copies a `cuco::pair{k, m}` to
850-
* unspecified locations in `[output_begin, output_end)`. Else, does nothing.
851-
*
852-
* @note Behavior is undefined if the size of the output range exceeds
853-
* `std::distance(output_begin, output_end)`.
854-
* @note Behavior is undefined if the given key has multiple matches in the set.
855-
*
856-
* @throw This API will always throw since it's not implemented.
857-
*
858-
* @tparam InputIt Device accessible input iterator
859-
* @tparam OutputIt Device accessible output iterator whose `value_type` can be constructed from
860-
* `cuco::pair<ProbeKey, Key>`
861-
* @tparam ProbeEqual Binary callable equal type
862-
* @tparam ProbeHash Unary callable hasher type that can be constructed from
863-
* an integer value
864-
*
865-
* @param first Beginning of the sequence of probe keys
866-
* @param last End of the sequence of probe keys
867-
* @param output_begin Beginning of the sequence of probe key and set key pairs retrieved for each
868-
* probe key
869-
* @param probe_equal The binary function to compare set keys and probe keys for equality
870-
* @param probe_hash The unary function to hash probe keys
871-
* @param stream CUDA stream used for retrieve
872-
*
873-
* @return The iterator indicating the last valid pair in the output
874-
*/
875-
template <typename InputIt, typename OutputIt, typename ProbeEqual, typename ProbeHash>
876-
OutputIt retrieve(InputIt first,
877-
InputIt last,
878-
OutputIt output_begin,
879-
ProbeEqual const& probe_equal = ProbeEqual{},
880-
ProbeHash const& probe_hash = ProbeHash{},
881-
cuda::stream_ref stream = {}) const;
882-
883845
/**
884846
* @brief Retrieves all keys contained in the set.
885847
*

tests/CMakeLists.txt

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ ConfigureTest(STATIC_MAP_TEST
8989
static_map/key_sentinel_test.cu
9090
static_map/shared_memory_test.cu
9191
static_map/stream_test.cu
92-
static_map/rehash_test.cu)
92+
static_map/rehash_test.cu
93+
static_map/retrieve_test.cu)
9394

9495
###################################################################################################
9596
# - dynamic_map tests -----------------------------------------------------------------------------
@@ -114,15 +115,11 @@ ConfigureTest(STATIC_MULTISET_TEST
114115
# - static_multimap tests -------------------------------------------------------------------------
115116
ConfigureTest(STATIC_MULTIMAP_TEST
116117
static_multimap/count_test.cu
117-
static_multimap/custom_pair_retrieve_test.cu
118-
static_multimap/custom_type_test.cu
119118
static_multimap/find_test.cu
120119
static_multimap/heterogeneous_lookup_test.cu
121120
static_multimap/insert_contains_test.cu
122121
static_multimap/insert_if_test.cu
123122
static_multimap/multiplicity_test.cu
124-
static_multimap/non_match_test.cu
125-
static_multimap/pair_function_test.cu
126123
static_multimap/for_each_test.cu)
127124

128125
###################################################################################################

0 commit comments

Comments
 (0)