@@ -14,6 +14,131 @@ use crate::sims_rngs::choices::Choices;
1414use rand:: distr:: Bernoulli ;
1515use rand:: prelude:: * ;
1616
17+ /// A utility struct that provides methods for common random operations
18+ ///
19+ /// This struct wraps a random number generator and provides methods for
20+ /// common operations like coin flips, weighted choices, and generating
21+ /// collections of random values.
22+ pub struct RandomUtils < R : Rng > {
23+ rng : R ,
24+ }
25+
26+ impl < R : Rng > RandomUtils < R > {
27+ /// Create a new `RandomUtils` with the given random number generator
28+ ///
29+ /// # Arguments
30+ /// * `rng` - The random number generator to use
31+ #[ inline]
32+ pub fn new ( rng : R ) -> Self {
33+ Self { rng }
34+ }
35+
36+ /// Get a mutable reference to the internal RNG
37+ ///
38+ /// This is useful when you need to perform operations not directly
39+ /// provided by the `RandomUtils` methods.
40+ #[ inline]
41+ pub fn rng_mut ( & mut self ) -> & mut R {
42+ & mut self . rng
43+ }
44+
45+ /// Choose between options given weighted probabilities.
46+ ///
47+ /// # Arguments
48+ /// * `choices` - The choices to select from with their associated weights
49+ ///
50+ /// # Returns
51+ /// A reference to the selected item
52+ #[ inline]
53+ pub fn choose_weighted < ' a , T > ( & mut self , choices : & ' a Choices < T > ) -> & ' a T {
54+ choices. sample ( & mut self . rng )
55+ }
56+
57+ /// Simulates a fair coin flip with 50% probability of true/false
58+ ///
59+ /// # Returns
60+ /// `true` or `false` with equal probability
61+ #[ inline]
62+ #[ allow( clippy:: cast_possible_wrap, clippy:: as_conversions) ]
63+ pub fn coin_flip ( & mut self ) -> bool {
64+ ( self . rng . next_u32 ( ) as i32 ) < 0
65+ }
66+
67+ /// Simulates a biased coin flip with specified probability of true
68+ ///
69+ /// # Arguments
70+ /// * `p` - The probability of returning `true` (between 0.0 and 1.0)
71+ ///
72+ /// # Returns
73+ /// `true` with probability `p`, `false` with probability `1-p`
74+ ///
75+ /// # Panics
76+ /// This function will panic if `p` is not a valid probability
77+ /// (not between 0.0 and 1.0, inclusive).
78+ #[ inline]
79+ pub fn biased_coin_flip ( & mut self , p : f64 ) -> bool {
80+ let bernoulli = Bernoulli :: new ( p)
81+ . expect ( "Failed to create Bernoulli distribution due to invalid probability" ) ;
82+ bernoulli. sample ( & mut self . rng )
83+ }
84+
85+ /// Generates a vector of bools, where true has an independent probability of `p`.
86+ ///
87+ /// # Arguments
88+ /// * `p` - The probability of each element being `true` (between 0.0 and 1.0)
89+ /// * `n` - The number of bools to generate
90+ ///
91+ /// # Returns
92+ /// A vector of `n` bools where each is independently `true` with probability `p`
93+ ///
94+ /// # Panics
95+ /// This function will panic if `p` is not a valid probability
96+ /// (not between 0.0 and 1.0, inclusive).
97+ #[ inline]
98+ pub fn gen_bools ( & mut self , p : f64 , n : usize ) -> Vec < bool > {
99+ let bernoulli = Bernoulli :: new ( p)
100+ . expect ( "Failed to create Bernoulli distribution due to invalid probability" ) ;
101+ ( 0 ..n) . map ( |_| bernoulli. sample ( & mut self . rng ) ) . collect ( )
102+ }
103+
104+ /// Select a random index based on a weighted probability distribution
105+ ///
106+ /// # Arguments
107+ /// * `weights` - A slice of weights (values should be non-negative)
108+ ///
109+ /// # Returns
110+ /// A randomly selected index where the probability of each index is proportional to its weight
111+ ///
112+ /// # Panics
113+ /// This function will panic if:
114+ /// - The weights slice is empty
115+ /// - All weights are zero
116+ /// - Any weight is negative
117+ #[ inline]
118+ pub fn weighted_index ( & mut self , weights : & [ f64 ] ) -> usize {
119+ assert ! ( !weights. is_empty( ) , "Cannot select from empty weights" ) ;
120+
121+ let total: f64 = weights. iter ( ) . sum ( ) ;
122+ assert ! ( total > 0.0 , "Sum of weights must be positive" ) ;
123+
124+ let mut target = self . rng . random_range ( 0.0 ..total) ;
125+
126+ for ( i, & weight) in weights. iter ( ) . enumerate ( ) {
127+ assert ! ( weight >= 0.0 , "Weights must be non-negative" ) ;
128+ target -= weight;
129+ if target <= 0.0 {
130+ return i;
131+ }
132+ }
133+
134+ // This should never happen due to floating-point arithmetic
135+ weights. len ( ) - 1
136+ }
137+ }
138+
139+ // Keep backwards compatibility with the original functions
140+ // by providing standalone versions that delegate to RandomUtils
141+
17142/// Choose between options given weighted probabilities.
18143#[ inline]
19144pub fn choose_weighted < ' a , T , R : Rng > ( rng : & mut R , choices : & ' a Choices < T > ) -> & ' a T {
@@ -38,3 +163,64 @@ pub fn gen_bools<R: Rng>(rng: &mut R, p: f64, n: usize) -> Vec<bool> {
38163 . expect ( "Failed to create Bernoulli distribution due to invalid probability" ) ;
39164 ( 0 ..n) . map ( |_| bernoulli. sample ( rng) ) . collect ( )
40165}
166+
167+ #[ cfg( test) ]
168+ mod tests {
169+ use super :: * ;
170+ use rand:: SeedableRng ;
171+ use rand_chacha:: ChaCha8Rng ;
172+
173+ #[ test]
174+ fn test_random_utils_struct ( ) {
175+ // Create a seeded RNG for deterministic tests
176+ let rng = ChaCha8Rng :: seed_from_u64 ( 42 ) ;
177+ let mut random_utils = RandomUtils :: new ( rng) ;
178+
179+ // Test coin_flip
180+ let flips: Vec < bool > = ( 0 ..100 ) . map ( |_| random_utils. coin_flip ( ) ) . collect ( ) ;
181+ let true_count = flips. iter ( ) . filter ( |& & b| b) . count ( ) ;
182+ // With a fair coin, we expect roughly 50 trues, but there's randomness
183+ assert ! ( true_count > 30 && true_count < 70 ) ;
184+
185+ // Test biased_coin_flip
186+ let biased_flips: Vec < bool > = ( 0 ..100 )
187+ . map ( |_| random_utils. biased_coin_flip ( 0.7 ) )
188+ . collect ( ) ;
189+ let biased_true_count = biased_flips. iter ( ) . filter ( |& & b| b) . count ( ) ;
190+ // With p=0.7, we expect roughly 70 trues, but there's randomness
191+ assert ! ( biased_true_count > 50 && biased_true_count < 90 ) ;
192+
193+ // Test gen_bools
194+ let bools = random_utils. gen_bools ( 0.3 , 50 ) ;
195+ assert_eq ! ( bools. len( ) , 50 ) ;
196+ let bools_true_count = bools. iter ( ) . filter ( |& & b| b) . count ( ) ;
197+ // With p=0.3, we expect roughly 15 trues, but there's randomness
198+ assert ! ( bools_true_count > 5 && bools_true_count < 25 ) ;
199+
200+ // Test weighted_index
201+ let weights = [ 1.0 , 3.0 , 6.0 ] ;
202+ let mut counts = [ 0 , 0 , 0 ] ;
203+ for _ in 0 ..1000 {
204+ counts[ random_utils. weighted_index ( & weights) ] += 1 ;
205+ }
206+ // The middle value should be about 3x the first, and the last about 6x the first
207+ assert ! ( counts[ 1 ] > counts[ 0 ] * 2 ) ;
208+ assert ! ( counts[ 2 ] > counts[ 1 ] ) ;
209+ }
210+
211+ #[ test]
212+ #[ should_panic( expected = "Cannot select from empty weights" ) ]
213+ fn test_weighted_index_empty ( ) {
214+ let rng = ChaCha8Rng :: seed_from_u64 ( 42 ) ;
215+ let mut random_utils = RandomUtils :: new ( rng) ;
216+ random_utils. weighted_index ( & [ ] ) ;
217+ }
218+
219+ #[ test]
220+ #[ should_panic( expected = "Sum of weights must be positive" ) ]
221+ fn test_weighted_index_all_zeros ( ) {
222+ let rng = ChaCha8Rng :: seed_from_u64 ( 42 ) ;
223+ let mut random_utils = RandomUtils :: new ( rng) ;
224+ random_utils. weighted_index ( & [ 0.0 , 0.0 , 0.0 ] ) ;
225+ }
226+ }
0 commit comments