@@ -4,6 +4,7 @@ using Distributions
4
4
5
5
import Base. rand
6
6
import Random. seed!
7
+ import Random123: set_counter!
7
8
8
9
# Default RNG type for when nothing is specified
9
10
_BASE_RNG = Philox2x
29
30
Initialize TracedRNG with r as the inner RNG
30
31
"""
31
32
function TracedRNG (r:: Random123.AbstractR123 )
32
- set_counter! (r, 1 )
33
+ set_counter! (r, 0 )
33
34
return TracedRNG (1 , r, typeof (r. key)[])
34
35
end
35
36
@@ -48,42 +49,33 @@ Random.rng_native_52(rng::TracedRNG{U}) where {U} = Random.rng_native_52(rng.rng
48
49
Base. rand (rng:: TracedRNG{U} , :: Type{T} ) where {U,T} = Base. rand (rng. rng, T)
49
50
50
51
"""
51
- split(r::TracedRNG , n::Integer)
52
+ split(key::Integer , n::Integer)
52
53
53
- Split inner RNG into n new TracedRNG
54
+ Split key into n new keys
54
55
"""
55
- function split (r :: TracedRNG{T} , n:: Integer ) where {T}
56
- return map (i -> hash (r . rng . key, convert (UInt, r . rng . ctr1 + i)), 1 : n)
56
+ function split (key :: Integer , n:: Integer = 1 ) where {T}
57
+ return map (i -> hash (key, convert (UInt, i)), 1 : n)
57
58
end
58
59
59
60
"""
60
- seed!(r::TracedRNG, seed::Number)
61
-
62
- Set the key of the inner RNG as `seed`
63
- """
64
- function seed! (r:: TracedRNG{T} , seed) where {T}
65
- return seed! (r. rng, seed)
66
- end
67
-
68
- """
69
- load_state(r::TracedRNG, seed)
61
+ load_state!(r::TracedRNG, seed)
70
62
71
63
Load state from current model iteration. Random streams are now replayed
72
64
"""
73
- function load_state (rng:: TracedRNG{T} ) where {T}
65
+ function load_state! (rng:: TracedRNG{T} ) where {T}
74
66
key = rng. keys[rng. count]
75
67
Random. seed! (rng. rng, key)
76
- return set_counter! (rng. rng, rng . count )
68
+ return set_counter! (rng. rng, 0 )
77
69
end
78
70
79
71
"""
80
72
update_rng!(rng::TracedRNG)
81
73
82
74
Set key and counter of inner RNG to key and the running model step
83
75
"""
84
- function update_rng ! (rng:: TracedRNG{T} , key) where {T}
85
- seed! (rng, key)
86
- return set_counter! (rng. rng, rng . count )
76
+ function seed ! (rng:: TracedRNG{T} , key) where {T}
77
+ seed! (rng. rng , key)
78
+ return set_counter! (rng. rng, 0 )
87
79
end
88
80
89
81
"""
@@ -102,7 +94,11 @@ Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.ke
102
94
103
95
Set the counter of the TracedRNG, used to keep track of the current model step
104
96
"""
105
- set_count! (r:: TracedRNG , n:: Integer ) = r. count = n
97
+ set_counter! (r:: TracedRNG , n:: Integer ) = r. count = n
98
+
99
+ """
100
+ inc_counter!(r::TracedRNG, n::Integer=1)
106
101
107
- inc_count! (r:: TracedRNG , n:: Integer ) = r. count += n
108
- inc_count! (r:: TracedRNG ) = inc_count! (r, 1 )
102
+ Increase the model step counter by n
103
+ """
104
+ inc_counter! (r:: TracedRNG , n:: Integer = 1 ) = r. count += n
0 commit comments