Skip to content

Commit 69817e2

Browse files
authored
Fix uniform random key generation (#630)
Fixes #532
1 parent 6c57bb1 commit 69817e2

File tree

1 file changed

+24
-15
lines changed

1 file changed

+24
-15
lines changed

include/cuco/utility/key_generator.cuh

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
#include <cuco/detail/pair/helpers.cuh>
2121
#include <cuco/detail/utility/strong_type.cuh>
2222

23-
#include <cuda/functional>
23+
#include <cuda/std/cmath>
24+
#include <cuda/std/functional> // TODO include <cuda/std/algorithm> instead once available
2425
#include <cuda/std/limits>
2526
#include <cuda/std/span>
2627
#include <thrust/device_vector.h>
@@ -80,6 +81,7 @@ struct gaussian : public cuco::detail::strong_type<double> {
8081
} // namespace distribution
8182

8283
namespace detail {
84+
8385
/**
8486
* @brief Generate uniform functor
8587
*
@@ -94,29 +96,37 @@ struct generate_uniform_fn {
9496
*
9597
* @param num Number of elements to generate
9698
* @param dist Random number distribution
99+
* @param seed Random seed
97100
*/
98-
__host__ __device__ constexpr generate_uniform_fn(std::size_t num, Dist dist)
99-
: num_{num}, dist_{dist}
101+
__host__ __device__ constexpr generate_uniform_fn(std::size_t num, Dist dist, std::size_t seed)
102+
: num_{num}, dist_{dist}, seed_{seed}
100103
{
101104
}
102105

103106
/**
104107
* @brief Generates a random number of type `T` based on the given `seed`
105108
*
106-
* @param seed Random number generator seed
109+
* @param idx Index of the output element
107110
*
108111
* @return A resulting random number
109112
*/
110-
__host__ __device__ constexpr T operator()(std::size_t seed) const noexcept
113+
__host__ __device__ constexpr T operator()(std::size_t idx) const noexcept
111114
{
112115
RNG rng;
113-
thrust::uniform_int_distribution<T> uniform_dist{1, static_cast<T>(num_ / dist_.value)};
114-
rng.seed(seed);
116+
// Improved seeding using a linear congruential generator
117+
rng.seed(seed_ + idx * 1664525ull + 1013904223ull);
118+
// Calculate number of unique keys
119+
auto num_unique_keys = cuda::std::max<size_t>(
120+
1ull,
121+
static_cast<size_t>(
122+
cuda::std::ceil(static_cast<double>(num_) / static_cast<double>(dist_.value))));
123+
thrust::uniform_int_distribution<T> uniform_dist{0, static_cast<T>(num_unique_keys - 1)};
115124
return uniform_dist(rng);
116125
}
117126

118-
std::size_t num_; ///< Number of elements to generate
119-
Dist dist_; ///< Random number distribution
127+
std::size_t num_; ///< Number of elements to generate
128+
Dist dist_; ///< Random number distribution
129+
std::size_t seed_; ///< Random seed
120130
};
121131

122132
/**
@@ -270,18 +280,17 @@ class key_generator {
270280
using value_type = typename std::iterator_traits<OutputIt>::value_type;
271281

272282
if constexpr (std::is_same_v<Dist, distribution::unique>) {
273-
thrust::sequence(exec_policy, out_begin, out_end, 0);
283+
thrust::sequence(exec_policy, out_begin, out_end, value_type{0});
274284
thrust::shuffle(exec_policy, out_begin, out_end, this->rng_);
275285
} else if constexpr (std::is_same_v<Dist, distribution::uniform>) {
276286
size_t num_keys = thrust::distance(out_begin, out_end);
277-
278-
thrust::counting_iterator<size_t> seeds(this->rng_());
287+
size_t seed = this->rng_();
279288

280289
thrust::transform(exec_policy,
281-
seeds,
282-
seeds + num_keys,
290+
thrust::make_counting_iterator<size_t>(0),
291+
thrust::make_counting_iterator<size_t>(num_keys),
283292
out_begin,
284-
detail::generate_uniform_fn<value_type, Dist, RNG>{num_keys, dist});
293+
detail::generate_uniform_fn<value_type, Dist, RNG>{num_keys, dist, seed});
285294
} else if constexpr (std::is_same_v<Dist, distribution::gaussian>) {
286295
size_t num_keys = thrust::distance(out_begin, out_end);
287296

0 commit comments

Comments
 (0)