@@ -5,8 +5,8 @@ using Distributions
5
5
import Base. rand
6
6
import Random. seed!
7
7
8
- # Use Philox2x for now
9
- BASE_RNG = Philox2x
8
+ # Default RNG type for when nothing is specified
9
+ _BASE_RNG = Philox2x
10
10
11
11
"""
12
12
TracedRNG{R,T}
@@ -15,14 +15,22 @@ Wrapped random number generator from Random123 to keep track of random streams d
15
15
"""
16
16
mutable struct TracedRNG{T} < :
17
17
Random. AbstractRNG where {T<: (Random123.AbstractR123{R} where {R}) }
18
+ " Model step counter"
18
19
count:: Int
20
+ " Inner RNG"
19
21
rng:: T
22
+ " Array of keys"
20
23
keys
21
- counters
22
24
end
23
25
26
+ """
27
+ TracedRNG(r::Random123.AbstractR123)
28
+
29
+ Initialize TracedRNG with r as the inner RNG
30
+ """
24
31
function TracedRNG (r:: Random123.AbstractR123 )
25
- return TracedRNG (1 , r, typeof (r. key)[], typeof (r. ctr1)[])
32
+ set_counter! (r, 1 )
33
+ return TracedRNG (1 , r, typeof (r. key)[])
26
34
end
27
35
28
36
"""
31
39
Create a default TracedRNG
32
40
"""
33
41
function TracedRNG ()
34
- r = BASE_RNG ()
42
+ r = _BASE_RNG ()
35
43
return TracedRNG (r)
36
44
end
37
45
38
- # Plug into Random
46
+ # Connect to the Random API
39
47
Random. rng_native_52 (rng:: TracedRNG{U} ) where {U} = Random. rng_native_52 (rng. rng)
40
48
Base. rand (rng:: TracedRNG{U} , :: Type{T} ) where {U,T} = Base. rand (rng. rng, T)
41
49
42
50
"""
43
51
split(r::TracedRNG, n::Integer)
44
52
45
- Split keys of the internal Philox2x into n distinct seeds
53
+ Split inner RNG into n new TracedRNG
46
54
"""
47
55
function split (r:: TracedRNG{T} , n:: Integer ) where {T}
48
- n == 1 && return [r. rng. key]
49
56
return map (i -> hash (r. rng. key, convert (UInt, r. rng. ctr1 + i)), 1 : n)
50
57
end
51
58
52
59
"""
53
- update_rng !(r::TracedRNG, seed::Number)
60
+ seed !(r::TracedRNG, seed::Number)
54
61
55
- Set the key of the wrapped Philox2x rng
62
+ Set the key of the inner RNG as `seed`
56
63
"""
57
64
function seed! (r:: TracedRNG{T} , seed) where {T}
58
65
return seed! (r. rng, seed)
59
66
end
60
67
61
68
"""
62
- reset_rng (r::TracedRNG, seed)
69
+ load_state (r::TracedRNG, seed)
63
70
64
- Reset the rng to the running model step
71
+ Load state from current model iteration. Random streams are now replayed
65
72
"""
66
- function reset_rng! (rng:: TracedRNG{T} ) where {T}
73
+ function load_state (rng:: TracedRNG{T} ) where {T}
67
74
key = rng. keys[rng. count]
68
- ctr = rng. counters[rng. count]
69
75
Random. seed! (rng. rng, key)
70
- return set_counter! (rng. rng, ctr )
76
+ return set_counter! (rng. rng, rng . count )
71
77
end
72
78
79
+ """
80
+ save_state!(r::TracedRNG)
81
+
82
+ Track current key of the inner RNG
83
+ """
73
84
function save_state! (r:: TracedRNG{T} ) where {T}
74
- push! (r. keys, r. rng. key)
75
- return push! (r. counters, r. rng. ctr1)
85
+ return push! (r. keys, r. rng. key)
76
86
end
77
87
78
88
Base. copy (r:: TracedRNG{T} ) where {T} = TracedRNG (r. count, copy (r. rng), copy (r. keys))
0 commit comments