Skip to content

Commit 784f87b

Browse files
committed
Add support for custom RNG
1 parent 77309cd commit 784f87b

File tree

5 files changed

+39
-23
lines changed

5 files changed

+39
-23
lines changed

src/AdvancedMH.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import AbstractMCMC
55
using Distributions
66
using Requires
77

8-
using Random
8+
import Random
99

1010
# Exports
1111
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal, RandomWalkProposal
@@ -51,7 +51,7 @@ logdensity(model::DensityModel, t::Transition) = t.lp
5151

5252
# A basic chains constructor that works with the Transition struct we defined.
5353
function AbstractMCMC.bundle_samples(
54-
rng::AbstractRNG,
54+
rng::Random.AbstractRNG,
5555
model::DensityModel,
5656
s::Metropolis,
5757
N::Integer,
@@ -64,7 +64,7 @@ function AbstractMCMC.bundle_samples(
6464
end
6565

6666
function AbstractMCMC.bundle_samples(
67-
rng::AbstractRNG,
67+
rng::Random.AbstractRNG,
6868
model::DensityModel,
6969
s::Metropolis,
7070
N::Integer,

src/mcmcchains-connect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import .MCMCChains: Chains
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
5-
rng::AbstractRNG,
5+
rng::Random.AbstractRNG,
66
model::DensityModel,
77
s::Metropolis,
88
N::Integer,

src/mh-core.jl

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,62 +48,71 @@ end
4848
StaticMH(d) = MetropolisHastings(StaticProposal(d))
4949
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))
5050

51+
# default function without RNG
52+
propose(spl::MetropolisHastings, args...) = propose(Random.GLOBAL_RNG, spl, args...)
53+
5154
# Propose from a vector of proposals
5255
function propose(
56+
rng::Random.AbstractRNG,
5357
spl::MetropolisHastings{<:AbstractArray},
5458
model::DensityModel
5559
)
56-
proposal = map(p -> propose(p, model), spl.proposal)
60+
proposal = map(p -> propose(rng, p, model), spl.proposal)
5761
return Transition(model, proposal)
5862
end
5963

6064
function propose(
65+
rng::Random.AbstractRNG,
6166
spl::MetropolisHastings{<:AbstractArray},
6267
model::DensityModel,
6368
params_prev::Transition
6469
)
6570
proposal = map(spl.proposal, params_prev.params) do p, params
66-
propose(p, model, params)
71+
propose(rng, p, model, params)
6772
end
6873
return Transition(model, proposal)
6974
end
7075

7176
# Make a proposal from one Proposal struct.
7277
function propose(
78+
rng::Random.AbstractRNG,
7379
spl::MetropolisHastings{<:Proposal},
7480
model::DensityModel
7581
)
76-
proposal = propose(spl.proposal, model)
82+
proposal = propose(rng, spl.proposal, model)
7783
return Transition(model, proposal)
7884
end
7985

8086
function propose(
87+
rng::Random.AbstractRNG,
8188
spl::MetropolisHastings{<:Proposal},
8289
model::DensityModel,
8390
params_prev::Transition
8491
)
85-
proposal = propose(spl.proposal, model, params_prev.params)
92+
proposal = propose(rng, spl.proposal, model, params_prev.params)
8693
return Transition(model, proposal)
8794
end
8895

8996
# Make a proposal from a NamedTuple of Proposal.
9097
function propose(
98+
rng::Random.AbstractRNG,
9199
spl::MetropolisHastings{<:NamedTuple},
92100
model::DensityModel
93101
)
94102
proposal = map(spl.proposal) do p
95-
propose(p, model)
103+
propose(rng, p, model)
96104
end
97105
return Transition(model, proposal)
98106
end
99107

100108
function propose(
109+
rng::Random.AbstractRNG,
101110
spl::MetropolisHastings{<:NamedTuple},
102111
model::DensityModel,
103112
params_prev::Transition
104113
)
105114
proposal = map(spl.proposal, params_prev.params) do p, params
106-
propose(p, model, params)
115+
propose(rng, p, model, params)
107116
end
108117
return Transition(model, proposal)
109118
end
@@ -143,7 +152,7 @@ end
143152
# beginning of sampling. Return the initial parameter used
144153
# to define the sampler.
145154
function AbstractMCMC.step!(
146-
rng::AbstractRNG,
155+
rng::Random.AbstractRNG,
147156
model::DensityModel,
148157
spl::MetropolisHastings,
149158
N::Integer,
@@ -152,7 +161,7 @@ function AbstractMCMC.step!(
152161
kwargs...
153162
)
154163
if init_params === nothing
155-
return propose(spl, model)
164+
return propose(rng, spl, model)
156165
else
157166
return Transition(model, init_params)
158167
end
@@ -162,15 +171,15 @@ end
162171
# either a new proposal (if accepted) or the previous proposal
163172
# (if not accepted).
164173
function AbstractMCMC.step!(
165-
rng::AbstractRNG,
174+
rng::Random.AbstractRNG,
166175
model::DensityModel,
167176
spl::MetropolisHastings,
168177
::Integer,
169178
params_prev::Transition;
170179
kwargs...
171180
)
172181
# Generate a new proposal.
173-
params = propose(spl, model, params_prev)
182+
params = propose(rng, spl, model, params_prev)
174183

175184
# Calculate the log acceptance probability.
176185
logα = logdensity(model, params) - logdensity(model, params_prev) +

src/proposal.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ struct RandomWalkProposal{P} <: Proposal{P}
99
end
1010

1111
# Random draws
12-
Base.rand(p::Proposal{<:Distribution}) = rand(p.proposal)
13-
Base.rand(p::Proposal{<:AbstractArray}) = map(rand, p.proposal)
12+
Base.rand(p::Proposal, args...) = rand(Random.GLOBAL_RNG, p, args...)
13+
Base.rand(rng::Random.AbstractRNG, p::Proposal{<:Distribution}) = rand(rng, p.proposal)
14+
function Base.rand(rng::Random.AbstractRNG, p::Proposal{<:AbstractArray})
15+
return map(x -> rand(rng, x), p.proposal)
16+
end
1417

1518
# Densities
1619
Distributions.logpdf(p::Proposal{<:Distribution}, v) = logpdf(p.proposal, v)
@@ -23,16 +26,17 @@ end
2326
# Random Walk #
2427
###############
2528

26-
function propose(p::RandomWalkProposal, m::DensityModel)
27-
return propose(StaticProposal(p.proposal), m)
29+
function propose(rng::Random.AbstractRNG, p::RandomWalkProposal, m::DensityModel)
30+
return propose(rng, StaticProposal(p.proposal), m)
2831
end
2932

3033
function propose(
34+
rng::Random.AbstractRNG,
3135
proposal::RandomWalkProposal{<:Union{Distribution,AbstractArray}},
3236
model::DensityModel,
3337
t
3438
)
35-
return t + rand(proposal)
39+
return t + rand(rng, proposal)
3640
end
3741

3842
function q(
@@ -48,11 +52,12 @@ end
4852
##########
4953

5054
function propose(
55+
rng::Random.AbstractRNG,
5156
proposal::StaticProposal{<:Union{Distribution,AbstractArray}},
5257
model::DensityModel,
5358
t=nothing
5459
)
55-
return rand(proposal)
60+
return rand(rng, proposal)
5661
end
5762

5863
function q(
@@ -76,18 +81,20 @@ for T in (StaticProposal, RandomWalkProposal)
7681
end
7782

7883
function propose(
84+
rng::Random.AbstractRNG,
7985
proposal::Proposal{<:Function},
8086
model::DensityModel
8187
)
82-
return propose(proposal(), model)
88+
return propose(rng, proposal(), model)
8389
end
8490

8591
function propose(
92+
rng::Random.AbstractRNG,
8693
proposal::Proposal{<:Function},
8794
model::DensityModel,
8895
t
8996
)
90-
return propose(proposal(t), model)
97+
return propose(rng, proposal(t), model)
9198
end
9299

93100
function q(

src/structarray-connect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import .StructArrays: StructArray
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
5-
rng::AbstractRNG,
5+
rng::Random.AbstractRNG,
66
model::DensityModel,
77
s::Metropolis,
88
N::Integer,

0 commit comments

Comments
 (0)