2626#include < cuco/utility/traits.hpp>
2727
2828#include < cuda/atomic>
29+ #include < cuda/functional>
2930#include < cuda/std/__algorithm/max.h> // TODO #include <cuda/std/algorithm> once available
3031#include < cuda/std/bit>
3132#include < cuda/std/cstddef>
3233#include < cuda/std/span>
3334#include < cuda/std/utility>
3435#include < cuda/stream_ref>
36+ #include < thrust/iterator/constant_iterator.h>
3537#include < thrust/type_traits/is_contiguous_iterator.h>
3638
3739#include < cooperative_groups.h>
@@ -172,6 +174,60 @@ class hyperloglog_impl {
172174 */
173175 template <class InputIt >
174176 __host__ constexpr void add_async (InputIt first, InputIt last, cuda::stream_ref stream)
177+ {
178+ auto const always_true = thrust::constant_iterator<bool >(true );
179+ this ->add_if_async (first, last, always_true, cuda::std::identity{}, stream);
180+ }
181+
182+ /* *
183+ * @brief Adds items to be counted to the estimator.
184+ *
185+ * @note This function synchronizes the given stream. For asynchronous execution use
186+ * `add_async`.
187+ *
188+ * @tparam InputIt Device accessible random access input iterator where
189+ * <tt>std::is_convertible<std::iterator_traits<InputIt>::value_type,
190+ * T></tt> is `true`
191+ *
192+ * @param first Beginning of the sequence of items
193+ * @param last End of the sequence of items
194+ * @param stream CUDA stream this operation is executed in
195+ */
196+ template <class InputIt >
197+ __host__ constexpr void add (InputIt first, InputIt last, cuda::stream_ref stream)
198+ {
199+ this ->add_async (first, last, stream);
200+ #if CCCL_MAJOR_VERSION > 3 || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1)
201+ stream.sync ();
202+ #else
203+ stream.wait ();
204+ #endif
205+ }
206+
207+ /* *
208+ * @brief Asynchronously adds each item in the range `[first, last)` for which `pred`, applied
209+ * to the corresponding stencil element, returns true.
210+ *
211+ * @note The item `*(first + i)` is added if `pred( *(stencil + i) )` returns true.
212+ *
213+ * @tparam InputIt Device accessible random access input iterator where
214+ * <tt>std::is_convertible<std::iterator_traits<InputIt>::value_type,
215+ * T></tt> is `true`
216+ * @tparam StencilIt Device accessible random access iterator whose value_type is
217+ * convertible to Predicate's argument type
218+ * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` and
219+ * argument type is convertible from <tt>std::iterator_traits<StencilIt>::value_type</tt>
220+ *
221+ * @param first Beginning of the sequence of items
222+ * @param last End of the sequence of items
223+ * @param stencil Beginning of the stencil sequence
224+ * @param pred Predicate to test on every element in the range `[stencil, stencil +
225+ * std::distance(first, last))`
226+ * @param stream CUDA stream this operation is executed in
227+ */
228+ template <class InputIt , class StencilIt , class Predicate >
229+ __host__ constexpr void add_if_async (
230+ InputIt first, InputIt last, StencilIt stencil, Predicate pred, cuda::stream_ref stream)
175231 {
176232 auto const num_items = cuco::detail::distance (first, last);
177233 if (num_items == 0 ) { return ; }
@@ -181,8 +237,6 @@ class hyperloglog_impl {
181237 int const shmem_bytes = sketch_bytes ();
182238 void const * kernel = nullptr ;
183239
184- // In case the input iterator represents a contiguous memory segment we can employ efficient
185- // vectorized loads
186240 if constexpr (thrust::is_contiguous_iterator_v<InputIt>) {
187241 auto const ptr = thrust::raw_pointer_cast (&first[0 ]);
188242 auto constexpr max_vector_bytes = 32 ;
@@ -193,54 +247,60 @@ class hyperloglog_impl {
193247 switch (vector_size) {
194248 case 2 :
195249 kernel = reinterpret_cast <void const *>(
196- cuco::hyperloglog_ns::detail::add_shmem_vectorized<2 , hyperloglog_impl>);
250+ cuco::hyperloglog_ns::detail::
251+ add_if_shmem_vectorized<2 , StencilIt, Predicate, hyperloglog_impl>);
197252 break ;
198253 case 4 :
199254 kernel = reinterpret_cast <void const *>(
200- cuco::hyperloglog_ns::detail::add_shmem_vectorized<4 , hyperloglog_impl>);
255+ cuco::hyperloglog_ns::detail::
256+ add_if_shmem_vectorized<4 , StencilIt, Predicate, hyperloglog_impl>);
201257 break ;
202258 case 8 :
203259 kernel = reinterpret_cast <void const *>(
204- cuco::hyperloglog_ns::detail::add_shmem_vectorized<8 , hyperloglog_impl>);
260+ cuco::hyperloglog_ns::detail::
261+ add_if_shmem_vectorized<8 , StencilIt, Predicate, hyperloglog_impl>);
205262 break ;
206263 case 16 :
207264 kernel = reinterpret_cast <void const *>(
208- cuco::hyperloglog_ns::detail::add_shmem_vectorized<16 , hyperloglog_impl>);
265+ cuco::hyperloglog_ns::detail::
266+ add_if_shmem_vectorized<16 , StencilIt, Predicate, hyperloglog_impl>);
209267 break ;
210268 };
211269 }
212270
213271 if (kernel != nullptr and this ->try_reserve_shmem (kernel, shmem_bytes)) {
214272 if constexpr (thrust::is_contiguous_iterator_v<InputIt>) {
215- // We make use of the occupancy calculator to get the minimum number of blocks which still
216- // saturates the GPU. This reduces the shmem initialization overhead and atomic contention
217- // on the final register array during the merge phase.
218273 CUCO_CUDA_TRY (
219274 cudaOccupancyMaxPotentialBlockSize (&grid_size, &block_size, kernel, shmem_bytes));
220275
221276 auto const ptr = thrust::raw_pointer_cast (&first[0 ]);
222- void * kernel_args[] = {
223- (void *)(&ptr), // TODO can't use reinterpret_cast since it can't cast away const
224- (void *)(&num_items),
225- reinterpret_cast <void *>(this )};
277+ void * kernel_args[] = {(void *)(&ptr),
278+ (void *)(&num_items),
279+ (void *)(&stencil),
280+ (void *)(&pred),
281+ reinterpret_cast <void *>(this )};
226282 CUCO_CUDA_TRY (
227283 cudaLaunchKernel (kernel, grid_size, block_size, kernel_args, shmem_bytes, stream.get ()));
228284 }
229285 } else {
230286 kernel = reinterpret_cast <void const *>(
231- cuco::hyperloglog_ns::detail::add_shmem<InputIt, hyperloglog_impl>);
232- void * kernel_args[] = {(void *)(&first), (void *)(&num_items), reinterpret_cast <void *>(this )};
287+ cuco::hyperloglog_ns::detail::
288+ add_if_shmem<InputIt, StencilIt, Predicate, hyperloglog_impl>);
289+ void * kernel_args[] = {(void *)(&first),
290+ (void *)(&num_items),
291+ (void *)(&stencil),
292+ (void *)(&pred),
293+ reinterpret_cast <void *>(this )};
233294 if (this ->try_reserve_shmem (kernel, shmem_bytes)) {
234295 CUCO_CUDA_TRY (
235296 cudaOccupancyMaxPotentialBlockSize (&grid_size, &block_size, kernel, shmem_bytes));
236297
237298 CUCO_CUDA_TRY (
238299 cudaLaunchKernel (kernel, grid_size, block_size, kernel_args, shmem_bytes, stream.get ()));
239300 } else {
240- // Computes sketch directly in global memory. (Fallback path in case there is not enough
241- // shared memory avalable)
242301 kernel = reinterpret_cast <void const *>(
243- cuco::hyperloglog_ns::detail::add_gmem<InputIt, hyperloglog_impl>);
302+ cuco::hyperloglog_ns::detail::
303+ add_if_gmem<InputIt, StencilIt, Predicate, hyperloglog_impl>);
244304
245305 CUCO_CUDA_TRY (cudaOccupancyMaxPotentialBlockSize (&grid_size, &block_size, kernel, 0 ));
246306
@@ -250,31 +310,6 @@ class hyperloglog_impl {
250310 }
251311 }
252312
253- /* *
254- * @brief Adds to be counted items to the estimator.
255- *
256- * @note This function synchronizes the given stream. For asynchronous execution use
257- * `add_async`.
258- *
259- * @tparam InputIt Device accessible random access input iterator where
260- * <tt>std::is_convertible<std::iterator_traits<InputIt>::value_type,
261- * T></tt> is `true`
262- *
263- * @param first Beginning of the sequence of items
264- * @param last End of the sequence of items
265- * @param stream CUDA stream this operation is executed in
266- */
267- template <class InputIt >
268- __host__ constexpr void add (InputIt first, InputIt last, cuda::stream_ref stream)
269- {
270- this ->add_async (first, last, stream);
271- #if CCCL_MAJOR_VERSION > 3 || (CCCL_MAJOR_VERSION == 3 && CCCL_MINOR_VERSION >= 1)
272- stream.sync ();
273- #else
274- stream.wait ();
275- #endif
276- }
277-
278313 /* *
279314 * @brief Merges the result of `other` estimator reference into `*this` estimator reference.
280315 *
0 commit comments