Skip to content

Commit 0b9a696

Browse files
committed
Add optional rng argument to sampling functions
1 parent 49e0b69 commit 0b9a696

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.1.0"
44

55
[deps]
66
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
7+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
78
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
89

910
[compat]

src/AdaptiveRejectionSampling.jl

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ exponential. As the resulting extremely precise envelop adapts, the rejection ra
55
"""
66
module AdaptiveRejectionSampling
77
# ------------------------------
8+
using Random # Random stdlib
9+
# ------------------------------
810
using ForwardDiff # For automatic differentiation, no user nor approximate derivatives
911
using StatsBase # To include the basic sample from array function
1012
# ------------------------------
@@ -98,18 +100,24 @@ function add_segment!(e::Envelop, l::Line)
98100
end
99101

100102
"""
101-
sample_envelop(p::Envelop)
103+
sample_envelop(rng::AbstractRNG, e::Envelop)
104+
sample_envelop(e::Envelop)
102105
Samples an element from the density defined by the envelop `e` with it's exponential weights.
103106
See [`Envelop`](@Envelop) for details.
104107
"""
105-
function sample_envelop(e::Envelop)
108+
function sample_envelop(rng::AbstractRNG, e::Envelop)
106109
# Randomly select lines based on envelop weights
107-
i = sample(1:e.size, weights(e.weights))
110+
i = sample(rng, 1:e.size, weights(e.weights))
108111
a, b = e.lines[i].slope, e.lines[i].intercept
109112
# Use the inverse CDF method for sampling
110-
log(exp(-b) * rand() * e.weights[i] * a + exp(a * e.cutpoints[i])) / a
113+
log(exp(-b) * rand(rng) * e.weights[i] * a + exp(a * e.cutpoints[i])) / a
114+
end
115+
116+
function sample_envelop(e::Envelop)
117+
sample_envelop(Random.GLOBAL_RNG, e)
111118
end
112119

120+
113121
"""
114122
eval_envelop(e::Envelop, x::Float64)
115123
Eval point a point `x` in the piecewise linear function defined by `e`. Necessary for evaluating
@@ -208,18 +216,19 @@ struct RejectionSampler
208216
end
209217

210218
"""
219+
run_sampler!(rng::AbstractRNG, sampler::RejectionSampler, n::Int)
211220
run_sampler!(sampler::RejectionSampler, n::Int)
212221
It draws `n` iid samples of the objective function of `sampler`, and at each iteration it adapts the envelop
213222
of `sampler` by adding new segments to its envelop.
214223
"""
215-
function run_sampler!(s::RejectionSampler, n::Int)
224+
function run_sampler!(rng::AbstractRNG, s::RejectionSampler, n::Int)
216225
i = 0
217226
failed, max_failed = 0, trunc(Int, n / s.max_failed_rate)
218227
out = zeros(n)
219228
while i < n
220-
candidate = sample_envelop(s.envelop)
229+
candidate = sample_envelop(rng, s.envelop)
221230
acceptance_ratio = exp(s.objective.logf(candidate)) / eval_envelop(s.envelop, candidate)
222-
if rand() < acceptance_ratio
231+
if rand(rng) < acceptance_ratio
223232
i += 1
224233
out[i] = candidate
225234
else
@@ -234,4 +243,8 @@ function run_sampler!(s::RejectionSampler, n::Int)
234243
end
235244
out
236245
end
237-
end #
246+
247+
function run_sampler!(s::RejectionSampler, n::Int)
248+
run_sampler!(Random.GLOBAL_RNG, s, n)
249+
end
250+
end # module AdaptiveRejectionSampling

0 commit comments

Comments
 (0)