Skip to content

Commit 5690c9d

Browse files
authored
Merge pull request #23 from devmotion/suggestions
Some suggestions
2 parents 02a6cb9 + 6c6fc70 commit 5690c9d

File tree

8 files changed

+108
-146
lines changed

8 files changed

+108
-146
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.4.1"
3+
version = "0.5.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

README.md

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,34 +75,35 @@ supported methods are `Array{Proposal}`, `Proposal`, and `NamedTuple{Proposal}`.
7575
```julia
7676
# Provide a univariate proposal.
7777
m1 = DensityModel(x -> logpdf(Normal(x,1), 1.0))
78-
p1 = Proposal(Static(), Normal(0,1))
78+
p1 = StaticProposal(Normal(0,1))
7979
c1 = sample(m1, MetropolisHastings(p1), 100; chain_type=Vector{NamedTuple})
8080

8181
# Draw from a vector of distributions.
8282
m2 = DensityModel(x -> logpdf(Normal(x[1], x[2]), 1.0))
83-
p2 = Proposal(Static(), [Normal(0,1), InverseGamma(2,3)])
83+
p2 = StaticProposal([Normal(0,1), InverseGamma(2,3)])
8484
c2 = sample(m2, MetropolisHastings(p2), 100; chain_type=Vector{NamedTuple})
8585

8686
# Draw from a `NamedTuple` of distributions.
8787
m3 = DensityModel(x -> logpdf(Normal(x.a, x.b), 1.0))
88-
p3 = (a=Proposal(Static(), Normal(0,1)), b=Proposal(Static(), InverseGamma(2,3)))
88+
p3 = (a=StaticProposal(Normal(0,1)), b=StaticProposal(InverseGamma(2,3)))
8989
c3 = sample(m3, MetropolisHastings(p3), 100; chain_type=Vector{NamedTuple})
9090

9191
# Draw from a functional proposal.
9292
m4 = DensityModel(x -> logpdf(Normal(x,1), 1.0))
93-
p4 = Proposal(Static(), (x=1.0) -> Normal(x, 1))
93+
p4 = StaticProposal((x=1.0) -> Normal(x, 1))
9494
c4 = sample(m4, MetropolisHastings(p4), 100; chain_type=Vector{NamedTuple})
9595
```
9696

9797
## Static vs. Random Walk
9898

9999
Currently there are only two methods of inference available. Static MH simply draws from the prior, with no
100100
conditioning on the previous sample. Random walk will add the proposal to the previously observed value.
101-
If you are constructing a `Proposal` by hand, you can determine whether the proposal is `Static` or `RandomWalk` using
101+
If you are constructing a `Proposal` by hand, you can determine whether the proposal is a
102+
`StaticProposal` or a `RandomWalkProposal` using
102103

103104
```julia
104-
static_prop = Proposal(Static(), Normal(0,1))
105-
rw_prop = Proposal(RandomWalk(), Normal(0,1))
105+
static_prop = StaticProposal(Normal(0,1))
106+
rw_prop = RandomWalkProposal(Normal(0,1))
106107
```
107108

108109
Different methods are easily composeable. One parameter can be static and another can be a random walk,

src/AdvancedMH.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,17 @@ import AbstractMCMC
55
using Distributions
66
using Requires
77

8-
using Random
8+
import Random
99

1010
# Exports
11-
export MetropolisHastings, DensityModel, RWMH, StaticMH, Proposal, Static, RandomWalk
11+
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal, RandomWalkProposal
1212

1313
# Reexports
1414
using AbstractMCMC: sample, psample
1515
export sample, psample
1616

1717
# Abstract type for MH-style samplers.
1818
abstract type Metropolis <: AbstractMCMC.AbstractSampler end
19-
abstract type ProposalStyle end
20-
21-
struct RandomWalk <: ProposalStyle end
22-
struct Static <: ProposalStyle end
2319

2420
# Define a model type. Stores the log density function and the data to
2521
# evaluate the log density on.
@@ -55,7 +51,7 @@ logdensity(model::DensityModel, t::Transition) = t.lp
5551

5652
# A basic chains constructor that works with the Transition struct we defined.
5753
function AbstractMCMC.bundle_samples(
58-
rng::AbstractRNG,
54+
rng::Random.AbstractRNG,
5955
model::DensityModel,
6056
s::Metropolis,
6157
N::Integer,
@@ -68,7 +64,7 @@ function AbstractMCMC.bundle_samples(
6864
end
6965

7066
function AbstractMCMC.bundle_samples(
71-
rng::AbstractRNG,
67+
rng::Random.AbstractRNG,
7268
model::DensityModel,
7369
s::Metropolis,
7470
N::Integer,

src/mcmcchains-connect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import .MCMCChains: Chains
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
5-
rng::AbstractRNG,
5+
rng::Random.AbstractRNG,
66
model::DensityModel,
77
s::Metropolis,
88
N::Integer,

src/mh-core.jl

Lines changed: 44 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@ x = (a = 1.0, b=3.8)
1212
The proposal would be
1313
1414
```julia
15-
proposal = (a=Proposal(Static(), Normal(0,1)), b=Proposal(Static(), Normal(0,1)))
15+
proposal = (a=StaticProposal(Normal(0,1)), b=StaticProposal(Normal(0,1)))
1616
````
1717
18-
Other allowed proposal styles are
18+
Other allowed proposals are
1919
2020
```
21-
p1 = Proposal(Static(), Normal(0,1))
22-
p2 = Proposal(Static(), [Normal(0,1), InverseGamma(2,3)])
23-
p3 = Proposal(Static(), (a=Normal(0,1), b=InverseGamma(2,3)))
24-
p4 = Proposal(Static(), (x=1.0) -> Normal(x, 1))
21+
p1 = StaticProposal(Normal(0,1))
22+
p2 = StaticProposal([Normal(0,1), InverseGamma(2,3)])
23+
p3 = StaticProposal(a=Normal(0,1), b=InverseGamma(2,3))
24+
p4 = StaticProposal((x=1.0) -> Normal(x, 1))
2525
```
2626
2727
The sampler is constructed using
@@ -41,98 +41,92 @@ used if `chain_type=Chains`.
4141
types are `chain_type=Chains` if `MCMCChains` is imported, or
4242
`chain_type=StructArray` if `StructArrays` is imported.
4343
"""
44-
mutable struct MetropolisHastings{D} <: Metropolis
45-
proposal :: D
44+
struct MetropolisHastings{D} <: Metropolis
45+
proposal::D
4646
end
4747

48-
StaticMH(d) = MetropolisHastings(Proposal(Static(), d))
49-
RWMH(d) = MetropolisHastings(Proposal(RandomWalk(), d))
48+
StaticMH(d) = MetropolisHastings(StaticProposal(d))
49+
RWMH(d) = MetropolisHastings(RandomWalkProposal(d))
50+
51+
# default function without RNG
52+
propose(spl::MetropolisHastings, args...) = propose(Random.GLOBAL_RNG, spl, args...)
5053

5154
# Propose from a vector of proposals
5255
function propose(
56+
rng::Random.AbstractRNG,
5357
spl::MetropolisHastings{<:AbstractArray},
5458
model::DensityModel
5559
)
56-
proposal = map(i -> propose(spl.proposal[i], model), 1:length(spl.proposal))
60+
proposal = map(p -> propose(rng, p, model), spl.proposal)
5761
return Transition(model, proposal)
5862
end
5963

6064
function propose(
65+
rng::Random.AbstractRNG,
6166
spl::MetropolisHastings{<:AbstractArray},
6267
model::DensityModel,
6368
params_prev::Transition
6469
)
65-
proposal = map(i -> propose(spl.proposal[i], model, params_prev.params[i]), 1:length(spl.proposal))
70+
proposal = map(spl.proposal, params_prev.params) do p, params
71+
propose(rng, p, model, params)
72+
end
6673
return Transition(model, proposal)
6774
end
6875

6976
# Make a proposal from one Proposal struct.
7077
function propose(
78+
rng::Random.AbstractRNG,
7179
spl::MetropolisHastings{<:Proposal},
7280
model::DensityModel
7381
)
74-
proposal = propose(spl.proposal, model)
82+
proposal = propose(rng, spl.proposal, model)
7583
return Transition(model, proposal)
7684
end
7785

7886
function propose(
87+
rng::Random.AbstractRNG,
7988
spl::MetropolisHastings{<:Proposal},
8089
model::DensityModel,
8190
params_prev::Transition
8291
)
83-
proposal = propose(spl.proposal, model, params_prev.params)
92+
proposal = propose(rng, spl.proposal, model, params_prev.params)
8493
return Transition(model, proposal)
8594
end
8695

8796
# Make a proposal from a NamedTuple of Proposal.
8897
function propose(
98+
rng::Random.AbstractRNG,
8999
spl::MetropolisHastings{<:NamedTuple},
90100
model::DensityModel
91101
)
92-
proposal = _propose(spl.proposal, model)
93-
return Transition(model, proposal)
94-
end
95-
96-
@generated function _propose(
97-
proposals::NamedTuple{names},
98-
model::DensityModel
99-
) where {names}
100-
expr = Expr(:tuple)
101-
map(names) do f
102-
push!(expr.args, Expr(:(=), f, :(propose(proposals.$f, model)) ))
102+
proposal = map(spl.proposal) do p
103+
propose(rng, p, model)
103104
end
104-
return expr
105+
return Transition(model, proposal)
105106
end
106107

107108
function propose(
109+
rng::Random.AbstractRNG,
108110
spl::MetropolisHastings{<:NamedTuple},
109111
model::DensityModel,
110112
params_prev::Transition
111113
)
112-
proposal = _propose(spl.proposal, model, params_prev.params)
113-
return Transition(model, proposal)
114-
end
115-
116-
@generated function _propose(
117-
proposals::NamedTuple{names},
118-
model::DensityModel,
119-
params_prev::NamedTuple
120-
) where {names}
121-
expr = Expr(:tuple)
122-
map(names) do f
123-
push!(expr.args, Expr(:(=), f, :(propose(proposals.$f, model, params_prev.$f)) ))
114+
proposal = map(spl.proposal, params_prev.params) do p, params
115+
propose(rng, p, model, params)
124116
end
125-
return expr
117+
return Transition(model, proposal)
126118
end
127119

128-
129120
# Evaluate the likelihood of t conditional on t_cond.
130121
function q(
131122
spl::MetropolisHastings{<:AbstractArray},
132123
t::Transition,
133124
t_cond::Transition
134125
)
135-
return sum(map(i -> q(spl.proposal[i], t.params[i], t_cond.params[i]), 1:length(spl.proposal)))
126+
# mapreduce with multiple iterators requires Julia 1.2 or later
127+
return mapreduce(+, 1:length(spl.proposal)) do i
128+
q(spl.proposal[i], t.params[i], t_cond.params[i])
129+
end
136130
end
137131

138132
function q(
@@ -148,21 +142,17 @@ function q(
148142
t::Transition,
149143
t_cond::Transition
150144
)
151-
ks = keys(t.params)
152-
total = 0.0
153-
154-
for k in ks
155-
total += q(spl.proposal[k], t.params[k], t_cond.params[k])
145+
# mapreduce with multiple iterators requires Julia 1.2 or later
146+
return mapreduce(+, keys(t.params)) do k
147+
q(spl.proposal[k], t.params[k], t_cond.params[k])
156148
end
157-
158-
return total
159149
end
160150

161151
# Define the first step! function, which is called at the
162152
# beginning of sampling. Return the initial parameter used
163153
# to define the sampler.
164154
function AbstractMCMC.step!(
165-
rng::AbstractRNG,
155+
rng::Random.AbstractRNG,
166156
model::DensityModel,
167157
spl::MetropolisHastings,
168158
N::Integer,
@@ -171,7 +161,7 @@ function AbstractMCMC.step!(
171161
kwargs...
172162
)
173163
if init_params === nothing
174-
return propose(spl, model)
164+
return propose(rng, spl, model)
175165
else
176166
return Transition(model, init_params)
177167
end
@@ -181,22 +171,22 @@ end
181171
# either a new proposal (if accepted) or the previous proposal
182172
# (if not accepted).
183173
function AbstractMCMC.step!(
184-
rng::AbstractRNG,
174+
rng::Random.AbstractRNG,
185175
model::DensityModel,
186176
spl::MetropolisHastings,
187177
::Integer,
188178
params_prev::Transition;
189179
kwargs...
190180
)
191181
# Generate a new proposal.
192-
params = propose(spl, model, params_prev)
182+
params = propose(rng, spl, model, params_prev)
193183

194184
# Calculate the log acceptance probability.
195-
α = logdensity(model, params) - logdensity(model, params_prev) +
185+
logα = logdensity(model, params) - logdensity(model, params_prev) +
196186
q(spl, params_prev, params) - q(spl, params, params_prev)
197187

198188
# Decide whether to return the previous params or the new one.
199-
if log(rand(rng)) < min(0.0, α)
189+
if -Random.randexp(rng) < logα
200190
return params
201191
else
202192
return params_prev

0 commit comments

Comments
 (0)