Skip to content

Commit fb08b5e

Browse files
Refactor c2h gen to ensure teardown before main (#7067)
Fixes: #7063
1 parent 3a008bc commit fb08b5e

File tree

5 files changed

+76
-56
lines changed

5 files changed

+76
-56
lines changed

c2h/generators.cu

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
#include <thrust/iterator/transform_iterator.h>
1010
#include <thrust/tabulate.h>
1111

12+
#include <cuda/std/optional>
13+
1214
#include <c2h/bfloat16.cuh>
1315
#include <c2h/custom_type.h>
1416
#include <c2h/detail/generators.cuh>
@@ -44,29 +46,77 @@ struct i_to_rnd_t
4446
};
4547
#endif // !C2H_HAS_CURAND
4648

47-
void generator_t::generate()
49+
class generator_t
4850
{
51+
public:
52+
generator_t()
53+
{
54+
#if C2H_HAS_CURAND
55+
curandCreateGenerator(&m_gen, CURAND_RNG_PSEUDO_DEFAULT);
56+
#endif
57+
}
58+
59+
~generator_t()
60+
{
61+
#if C2H_HAS_CURAND
62+
curandDestroyGenerator(m_gen);
63+
#endif
64+
}
65+
66+
float* prepare_random_generator(seed_t seed, std::size_t num_items)
67+
{
68+
m_distribution.resize(num_items);
69+
4970
#if C2H_HAS_CURAND
50-
curandGenerateUniform(m_gen, thrust::raw_pointer_cast(m_distribution.data()), m_distribution.size());
71+
curandSetPseudoRandomGeneratorSeed(m_gen, seed.get());
5172
#else
52-
thrust::tabulate(device_policy, m_distribution.begin(), m_distribution.end(), i_to_rnd_t{m_re});
53-
m_re.discard(m_distribution.size());
73+
m_gen.seed(seed.get());
5474
#endif
55-
}
5675

57-
float* generator_t::prepare_random_generator(seed_t seed, std::size_t num_items)
58-
{
59-
m_distribution.resize(num_items);
76+
generate();
6077

78+
return thrust::raw_pointer_cast(m_distribution.data());
79+
}
80+
81+
// re-fills the currently held distribution vector with new random values
82+
void generate()
83+
{
84+
#if C2H_HAS_CURAND
85+
curandGenerateUniform(m_gen, thrust::raw_pointer_cast(m_distribution.data()), m_distribution.size());
86+
#else
87+
thrust::tabulate(device_policy, m_distribution.begin(), m_distribution.end(), i_to_rnd_t{m_gen});
88+
m_gen.discard(m_distribution.size());
89+
#endif
90+
}
91+
92+
private:
6193
#if C2H_HAS_CURAND
62-
curandSetPseudoRandomGeneratorSeed(m_gen, seed.get());
94+
curandGenerator_t
6395
#else
64-
m_re.seed(seed.get());
96+
thrust::default_random_engine
6597
#endif
98+
m_gen;
99+
c2h::device_vector<float> m_distribution;
100+
};
66101

67-
generate();
102+
// global generator state
103+
cuda::std::optional<generator_t> generator;
68104

69-
return thrust::raw_pointer_cast(m_distribution.data());
105+
void init_generator()
106+
{
107+
_CCCL_VERIFY(!generator.has_value(), "");
108+
generator.emplace();
109+
}
110+
111+
float* prepare_random_data(seed_t seed, std::size_t num_items)
112+
{
113+
return generator.value().prepare_random_generator(seed, num_items);
114+
}
115+
116+
void cleanup_generator()
117+
{
118+
_CCCL_VERIFY(generator.has_value(), "");
119+
generator.reset();
70120
}
71121

72122
struct random_to_custom_t
@@ -94,7 +144,7 @@ void gen_custom_type_state(
94144
std::size_t element_size)
95145
{
96146
// FIXME(bgruber): implement min/max handling for custom_type_state_t
97-
float* d_in = generator.prepare_random_generator(seed, elements * 2);
147+
float* d_in = prepare_random_data(seed, elements * 2);
98148
thrust::for_each(device_policy,
99149
thrust::counting_iterator<std::size_t>{0},
100150
thrust::counting_iterator<std::size_t>{elements},

c2h/generators_gen_values.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ namespace c2h::detail
1616
template <typename T>
1717
void gen_values_between(seed_t seed, ::cuda::std::span<T> data, T min, T max)
1818
{
19-
const auto* dist = generator.prepare_random_generator(seed, data.size());
19+
const auto* dist = prepare_random_data(seed, data.size());
2020
thrust::transform(device_policy, dist, dist + data.size(), data.begin(), random_to_item_t<T>(min, max));
2121
}
2222

c2h/generators_vector.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ struct random_to_vec_item_t
5353
template <> \
5454
void gen_values_between(seed_t seed, ::cuda::std::span<T> data, T min, T max) \
5555
{ \
56-
const auto* dist = generator.prepare_random_generator(seed, data.size()); \
56+
const auto* dist = prepare_random_data(seed, data.size()); \
5757
auto op = random_to_vec_item_t<T, ::cuda::std::tuple_size_v<T>>{min, max, dist, data.data()}; \
5858
thrust::for_each( \
5959
device_policy, thrust::counting_iterator<size_t>{0}, thrust::counting_iterator<size_t>{data.size()}, op); \

c2h/include/c2h/catch2_main.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
#include <thrust/detail/config/device_system.h>
77

8-
#include <iostream>
8+
#include <c2h/detail/generators.cuh>
99

1010
//! @file
1111
//! This file includes a custom Catch2 main function. When CMake is configured to build each test as a separate
@@ -43,6 +43,9 @@ int main(int argc, char* argv[])
4343

4444
set_device(device_id);
4545
# endif // THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA
46-
return session.run();
46+
c2h::detail::init_generator();
47+
const auto ret = session.run();
48+
c2h::detail::cleanup_generator();
49+
return ret;
4750
}
4851
#endif // C2H_CONFIG_MAIN

c2h/include/c2h/detail/generators.cuh

Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,50 +4,17 @@
44
#include <cuda/std/complex>
55

66
#include <c2h/generators.h>
7-
#include <c2h/vector.h>
8-
9-
#if C2H_HAS_CURAND
10-
# include <curand.h>
11-
#else
12-
# include <thrust/random.h>
13-
#endif
147

158
namespace c2h::detail
169
{
17-
class generator_t
18-
{
19-
public:
20-
generator_t()
21-
{
22-
#if C2H_HAS_CURAND
23-
curandCreateGenerator(&m_gen, CURAND_RNG_PSEUDO_DEFAULT);
24-
#endif
25-
}
26-
27-
~generator_t()
28-
{
29-
#if C2H_HAS_CURAND
30-
curandDestroyGenerator(m_gen);
31-
#endif
32-
}
33-
34-
// sets the seed and resizes the distribution vector, fills it by calling generate(), and returns a pointer the start
35-
// of the data
36-
float* prepare_random_generator(seed_t seed, std::size_t num_items);
10+
// called once from main to set up the generator state
11+
void init_generator();
3712

38-
// re-fills the currently held distribution vector with new random values
39-
void generate();
40-
41-
private:
42-
#if C2H_HAS_CURAND
43-
curandGenerator_t m_gen;
44-
#else
45-
thrust::default_random_engine m_re;
46-
#endif
47-
c2h::device_vector<float> m_distribution;
48-
};
13+
// sets the seed and resizes the distribution vector, fills it, and returns a pointer the start of the data
14+
float* prepare_random_data(seed_t seed, std::size_t num_items);
4915

50-
inline generator_t generator;
16+
// called once before main returns to clean up the generator state
17+
void cleanup_generator();
5118

5219
template <typename T, bool = ::cuda::is_floating_point_v<T>>
5320
struct random_to_item_t

0 commit comments

Comments
 (0)