Skip to content

Commit 9ed2bb8

Browse files
committed
fix(prng): fix UB in sample_random_variables to accept only generator as reference
1 parent ff13c73 commit 9ed2bb8

File tree

4 files changed

+51
-12
lines changed

4 files changed

+51
-12
lines changed

apps/cli/src/main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ int parse_callback_ok(
118118
h->set_logger(std::cref(logger));
119119
// h->set_feed_constant(7.701635339554948e-06, 2, 0, 0);
120120
const auto load_serde = user_params.load_serde;
121-
user_params.uniform_mc_init = false;
121+
// user_params.uniform_mc_init = false;
122122
INTERPRETER_INIT
123123
{
124124
HANDLE_RC(h->register_parameters(

apps/libs/mc/public/mc/prng/prng.hpp

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,53 @@ namespace MC
1414
Kokkos::Random_XorShift1024_Pool<Kokkos::DefaultExecutionSpace>;
1515
pool_type get_pool(std::size_t seed = 0);
1616

17-
/** @brief Sample variable with RAI mecanism*/
18-
KOKKOS_INLINE_FUNCTION auto sample_random_variables(const pool_type& p,
19-
auto functor)
17+
/**
18+
* @brief Samples random variables
19+
* Use t wrap generator with RAII
20+
*
21+
* The functor must accept the generator by reference, not by value.
22+
* Passing the generator by value (e.g., `[](auto gen){ ... }`) results in
23+
* undefined behavior, because copying the generator state is not allowed.
24+
* Instead, ensure the functor signature is `[](auto& gen){ ... }`.
25+
*
26+
* @tparam Functor A callable object that takes `generator_type&` and returns
27+
* any tuple-like or user-defined result.
28+
* @param pool Random pool
29+
* @param functor A callable that operates on the generator and produces the
30+
* desired random values.
31+
*
32+
* @return The value returned by the user-provided functor.
33+
*
34+
* @example
35+
* // Example usage:
36+
* const auto [a, b] = MC::sample_random_variables(
37+
* random_pool,
38+
* [](auto& gen)
39+
* {
40+
* const auto a = gen.urand64();
41+
* const auto b = gen.drand();
42+
* return std::make_tuple(a, b);
43+
* });
44+
*/
45+
46+
KOKKOS_INLINE_FUNCTION auto sample_random_variables(const pool_type& pool,
47+
auto&& functor)
2048
{
21-
auto gen = p.get_state();
22-
auto t = functor(gen);
23-
p.free_state(gen);
24-
return t;
49+
// Be sure that generator is passed to functor as reference
50+
// As functor can be lambda usually [](auto gen){return gen.rand();};
51+
// If auto is passed as value, value generated is UB. Need auto& to avoid
52+
// any problem
53+
using gen_t = typename pool_type::generator_type;
54+
using Functor = decltype(functor);
55+
static_assert(std::is_invocable_v<Functor, gen_t&> &&
56+
!std::is_invocable_v<Functor, gen_t>,
57+
"Functor must accept generator by reference only");
58+
59+
auto gen = pool.get_state();
60+
const auto result = functor(gen);
61+
pool.free_state(gen);
62+
63+
return result;
2564
}
2665

2766
#define SAMPLE_RANDOM_VARIABLES(_random_pool_, ...) \
@@ -89,7 +128,7 @@ namespace MC
89128
{
90129

91130
return sample_random_variables(
92-
random_pool, [a, b](auto gen) { return gen.urand64(a, b); });
131+
random_pool, [a, b](auto& gen) { return gen.urand64(a, b); });
93132
}
94133

95134
pool_type random_pool;

apps/libs/models/public/models/monod.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ namespace Models
8888

8989
const auto [l0, mu] = MC::sample_random_variables(
9090
random_pool,
91-
[](auto gen)
91+
[](auto& gen)
9292
{
9393
const auto l0 = local_dist.draw(
9494
gen); // Get initial fro given distribution (normal)

tools/cases.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
<!-- <cma_case_path >/home-local/casale/Documents/thesis/cfd-cma/cma_data/sanofi/</cma_case_path> -->
88

99
<final_time>30</final_time>
10-
<number_particle>500000</number_particle>
10+
<number_particle>5000</number_particle>
1111
<delta_time>0.001</delta_time>
1212
<results_file_name>debug</results_file_name>
13-
<number_exported_result>0</number_exported_result>
13+
<number_exported_result>10</number_exported_result>
1414
<model_name>monod</model_name>
1515
<!--<initialiser_path>./cma_data/0d_4s_init.h5</initialiser_path> -->
1616
<!-- <initialiser_path>./cma_data/n14_init.h5</initialiser_path> -->

0 commit comments

Comments
 (0)