|
1 |
| -using Random123 |
2 |
| -using Random |
3 |
| -using Distributions |
4 |
| - |
5 |
| -import Base.rand |
6 |
| -import Random.seed! |
7 |
| -import Random123: set_counter! |
8 |
| - |
9 | 1 | # Default RNG type for when nothing is specified
|
10 |
| -_BASE_RNG = Philox2x |
| 2 | +const _BASE_RNG = Random123.Philox2x |
11 | 3 |
|
12 | 4 | """
|
13 |
| - TracedRNG{R,T} |
| 5 | + TracedRNG{R,N,T} |
14 | 6 |
|
15 | 7 | Wrapped random number generator from Random123 to keep track of random streams during model evaluation
|
16 | 8 | """
|
17 |
| -mutable struct TracedRNG{T} <: |
18 |
| - Random.AbstractRNG where {T<:(Random123.AbstractR123{R} where {R})} |
| 9 | +mutable struct TracedRNG{R,N,T<:Random123.AbstractR123{R}} <: Random.AbstractRNG |
19 | 10 | "Model step counter"
|
20 | 11 | count::Int
|
21 | 12 | "Inner RNG"
|
22 | 13 | rng::T
|
23 | 14 | "Array of keys"
|
24 |
| - keys |
| 15 | + keys::Array{R,N} |
25 | 16 | end
|
26 | 17 |
|
27 | 18 | """
|
28 |
| - TracedRNG(r::Random123.AbstractR123) |
29 |
| -
|
30 |
| -Initialize TracedRNG with r as the inner RNG |
| 19 | + TracedRNG(r::Random123.AbstractR123=AdvancedPS._BASE_RNG()) |
| 20 | +Create a `TracedRNG` with `r` as the inner RNG. |
31 | 21 | """
|
32 |
| -function TracedRNG(r::Random123.AbstractR123) |
33 |
| - set_counter!(r, 0) |
| 22 | +function TracedRNG(r::Random123.AbstractR123=_BASE_RNG()) |
| 23 | + Random123.set_counter!(r, 0) |
34 | 24 | return TracedRNG(1, r, typeof(r.key)[])
|
35 | 25 | end
|
36 | 26 |
|
37 |
| -""" |
38 |
| - TracedRNG() |
39 |
| -
|
40 |
| -Create a default TracedRNG |
41 |
| -""" |
42 |
| -function TracedRNG() |
43 |
| - r = _BASE_RNG() |
44 |
| - return TracedRNG(r) |
45 |
| -end |
46 |
| - |
47 | 27 | # Connect to the Random API
|
48 |
| -Random.rng_native_52(rng::TracedRNG{U}) where {U} = Random.rng_native_52(rng.rng) |
49 |
| -Base.rand(rng::TracedRNG{U}, ::Type{T}) where {U,T} = Base.rand(rng.rng, T) |
| 28 | +Random.rng_native_52(rng::TracedRNG) = Random.rng_native_52(rng.rng) |
| 29 | +Base.rand(rng::TracedRNG, ::Type{T}) where {T} = Base.rand(rng.rng, T) |
50 | 30 |
|
51 | 31 | """
|
52 |
| - split(key::Integer, n::Integer) |
| 32 | + split(key::Integer, n::Integer=1) |
53 | 33 |
|
54 |
| -Split key into n new keys |
| 34 | +Split `key` into `n` new keys |
55 | 35 | """
|
56 |
| -function split(key::Integer, n::Integer=1) where {T} |
57 |
| - return map(i -> hash(key, convert(UInt, i)), 1:n) |
| 36 | +function split(key::Integer, n::Integer=1) |
| 37 | + return [hash(key, i) for i in UInt(1):UInt(n)] |
58 | 38 | end
|
59 | 39 |
|
60 | 40 | """
|
61 |
| - load_state!(r::TracedRNG, seed) |
| 41 | + load_state!(r::TracedRNG) |
62 | 42 |
|
63 | 43 | Load state from current model iteration. Random streams are now replayed
|
64 | 44 | """
|
65 |
| -function load_state!(rng::TracedRNG{T}) where {T} |
| 45 | +function load_state!(rng::TracedRNG) |
66 | 46 | key = rng.keys[rng.count]
|
67 | 47 | Random.seed!(rng.rng, key)
|
68 |
| - return set_counter!(rng.rng, 0) |
| 48 | + return Random123.set_counter!(rng.rng, 0) |
69 | 49 | end
|
70 | 50 |
|
71 | 51 | """
|
72 | 52 | update_rng!(rng::TracedRNG)
|
73 | 53 |
|
74 |
| -Set key and counter of inner RNG to key and the running model step |
| 54 | +Set key and counter of inner rng in `rng` to `key` and the running model step to 0 |
75 | 55 | """
|
76 |
| -function seed!(rng::TracedRNG{T}, key) where {T} |
77 |
| - seed!(rng.rng, key) |
78 |
| - return set_counter!(rng.rng, 0) |
| 56 | +function Random.seed!(rng::TracedRNG, key) |
| 57 | + Random.seed!(rng.rng, key) |
| 58 | + return Random123.set_counter!(rng.rng, 0) |
79 | 59 | end
|
80 | 60 |
|
81 | 61 | """
|
82 | 62 | save_state!(r::TracedRNG)
|
83 | 63 |
|
84 |
| -Track current key of the inner RNG |
| 64 | +Add current key of the inner rng in `r` to `keys`. |
85 | 65 | """
|
86 |
| -function save_state!(r::TracedRNG{T}) where {T} |
| 66 | +function save_state!(r::TracedRNG) |
87 | 67 | return push!(r.keys, r.rng.key)
|
88 | 68 | end
|
89 | 69 |
|
90 |
| -Base.copy(r::TracedRNG{T}) where {T} = TracedRNG(r.count, copy(r.rng), copy(r.keys)) |
| 70 | +Base.copy(r::TracedRNG) = TracedRNG(r.count, copy(r.rng), deepcopy(r.keys)) |
91 | 71 |
|
92 | 72 | """
|
93 |
| - set_count!(r::TracedRNG, n::Integer) |
| 73 | + set_counter!(r::TracedRNG, n::Integer) |
94 | 74 |
|
95 |
| -Set the counter of the TracedRNG, used to keep track of the current model step |
| 75 | +Set the counter of the inner rng in `r`, used to keep track of the current model step |
96 | 76 | """
|
97 |
| -set_counter!(r::TracedRNG, n::Integer) = r.count = n |
| 77 | +Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n |
98 | 78 |
|
99 | 79 | """
|
100 | 80 | inc_counter!(r::TracedRNG, n::Integer=1)
|
101 | 81 |
|
102 |
| -Increase the model step counter by n |
| 82 | +Increase the model step counter by `n` |
103 | 83 | """
|
104 | 84 | inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n
|
0 commit comments