@@ -5,6 +5,8 @@ exponential. As the resulting extremely precise envelop adapts, the rejection ra
55"""
66module AdaptiveRejectionSampling
77# ------------------------------
8+ using Random # Random stdlib
9+ # ------------------------------
810using ForwardDiff # For automatic differentiation, no user nor approximate derivatives
911using StatsBase # To include the basic sample from array function
1012# ------------------------------
@@ -98,18 +100,24 @@ function add_segment!(e::Envelop, l::Line)
98100end
99101
100102"""
101- sample_envelop(p::Envelop)
103+ sample_envelop(rng::AbstractRNG, e::Envelop)
104+ sample_envelop(e::Envelop)
102105Samples an element from the density defined by the envelop `e` with it's exponential weights.
103106See [`Envelop`](@Envelop) for details.
104107"""
105- function sample_envelop (e:: Envelop )
108+ function sample_envelop (rng :: AbstractRNG , e:: Envelop )
106109 # Randomly select lines based on envelop weights
107- i = sample (1 : e. size, weights (e. weights))
110+ i = sample (rng, 1 : e. size, weights (e. weights))
108111 a, b = e. lines[i]. slope, e. lines[i]. intercept
109112 # Use the inverse CDF method for sampling
110- log (exp (- b) * rand () * e. weights[i] * a + exp (a * e. cutpoints[i])) / a
113+ log (exp (- b) * rand (rng) * e. weights[i] * a + exp (a * e. cutpoints[i])) / a
114+ end
115+
116+ function sample_envelop (e:: Envelop )
117+ sample_envelop (Random. GLOBAL_RNG, e)
111118end
112119
120+
113121"""
114122 eval_envelop(e::Envelop, x::Float64)
115123Eval point a point `x` in the piecewise linear function defined by `e`. Necessary for evaluating
@@ -208,18 +216,19 @@ struct RejectionSampler
208216end
209217
210218"""
219+ run_sampler!(rng::AbstractRNG, sampler::RejectionSampler, n::Int)
211220 run_sampler!(sampler::RejectionSampler, n::Int)
212221It draws `n` iid samples of the objective function of `sampler`, and at each iteration it adapts the envelop
213222of `sampler` by adding new segments to its envelop.
214223"""
215- function run_sampler! (s:: RejectionSampler , n:: Int )
224+ function run_sampler! (rng :: AbstractRNG , s:: RejectionSampler , n:: Int )
216225 i = 0
217226 failed, max_failed = 0 , trunc (Int, n / s. max_failed_rate)
218227 out = zeros (n)
219228 while i < n
220- candidate = sample_envelop (s. envelop)
229+ candidate = sample_envelop (rng, s. envelop)
221230 acceptance_ratio = exp (s. objective. logf (candidate)) / eval_envelop (s. envelop, candidate)
222- if rand () < acceptance_ratio
231+ if rand (rng ) < acceptance_ratio
223232 i += 1
224233 out[i] = candidate
225234 else
@@ -234,4 +243,8 @@ function run_sampler!(s::RejectionSampler, n::Int)
234243 end
235244 out
236245end
237- end #
246+
247+ function run_sampler! (s:: RejectionSampler , n:: Int )
248+ run_sampler! (Random. GLOBAL_RNG, s, n)
249+ end
250+ end # module AdaptiveRejectionSampling
0 commit comments