1
1
# Default RNG type for when nothing is specified
2
- const _BASE_RNG = Random123. Philox2x
2
+ const _BASE_RNG = Random123. Philox2x # Rng with state bigger than 1 are broken because of Split
3
3
4
4
"""
5
5
TracedRNG{R,N,T}
6
6
7
7
Wrapped random number generator from Random123 to keep track of random streams during model evaluation
8
8
"""
9
- mutable struct TracedRNG{R,N,T<: Random123.AbstractR123{R} } <: Random.AbstractRNG
9
+ mutable struct TracedRNG{R,N,T<: Random123.AbstractR123 } <: Random.AbstractRNG
10
10
" Model step counter"
11
11
count:: Int
12
12
" Inner RNG"
@@ -21,7 +21,7 @@ Create a `TracedRNG` with `r` as the inner RNG.
21
21
"""
22
22
function TracedRNG (r:: Random123.AbstractR123 = _BASE_RNG ())
23
23
Random123. set_counter! (r, 0 )
24
- return TracedRNG (1 , r, typeof (r . key )[])
24
+ return TracedRNG (1 , r, Random123 . seed_type (r )[])
25
25
end
26
26
27
27
# Connect to the Random API
@@ -50,7 +50,7 @@ function load_state!(rng::TracedRNG)
50
50
end
51
51
52
52
"""
53
- update_rng !(rng::TracedRNG)
53
+ Random.seed !(rng::TracedRNG, key )
54
54
55
55
Set key and counter of inner rng in `rng` to `key` and the running model step to 0
56
56
"""
@@ -59,15 +59,33 @@ function Random.seed!(rng::TracedRNG, key)
59
59
return Random123. set_counter! (rng. rng, 0 )
60
60
end
61
61
62
+ """
63
+ gen_seed(rng::Random.AbstractRNG, subrng::TracedRNG, sampler::Random.Sampler)
64
+
65
+ Generate a `seed` for the subrng based on top-level `rng` and `sampler`
66
+ """
67
+ function gen_seed (rng:: Random.AbstractRNG , :: TracedRNG{<:Integer} , sampler:: Random.Sampler )
68
+ return rand (rng, sampler)
69
+ end
70
+
71
+ function gen_seed (
72
+ rng:: Random.AbstractRNG , :: TracedRNG{<:NTuple{N}} , sampler:: Random.Sampler
73
+ ) where {N}
74
+ return Tuple (rand (rng, sampler, N))
75
+ end
76
+
62
77
"""
63
78
save_state!(r::TracedRNG)
64
79
65
80
Add current key of the inner rng in `r` to `keys`.
66
81
"""
67
82
function save_state! (r:: TracedRNG )
68
- return push! (r. keys, r. rng. key )
83
+ return push! (r. keys, state ( r. rng) )
69
84
end
70
85
86
+ state (rng:: Random123.Philox2x ) = rng. key
87
+ state (rng:: Random123.Philox4x ) = (rng. key1, rng. key2)
88
+
71
89
Base. copy (r:: TracedRNG ) = TracedRNG (r. count, copy (r. rng), deepcopy (r. keys))
72
90
73
91
"""
0 commit comments