Skip to content

Commit aaeac98

Browse files
committed
Make AdvancedMH way cooler
1 parent 0a931d5 commit aaeac98

File tree

8 files changed

+332
-158
lines changed

8 files changed

+332
-158
lines changed

src/AdvancedMH.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,19 @@ using Requires
77
using Distributions
88

99
# Import specific functions and types to use or overload.
10-
import AbstractMCMC: step!, AbstractSampler, AbstractTransition, transition_type, bundle_samples
10+
import AbstractMCMC: step!, AbstractSampler, AbstractTransition,
11+
transition_type, bundle_samples
1112

1213
# Exports
13-
export MetropolisHastings, DensityModel, sample, psample, RWMH, StaticMH
14+
export MetropolisHastings, DensityModel, sample, psample, RWMH, StaticMH, Proposal
1415

1516
# Abstract type for MH-style samplers.
1617
abstract type Metropolis <: AbstractSampler end
1718
abstract type ProposalStyle end
1819

20+
struct RandomWalk <: ProposalStyle end
21+
struct Static <: ProposalStyle end
22+
1923
# Define a model type. Stores the log density function and the data to
2024
# evaluate the log density on.
2125
"""
@@ -57,7 +61,7 @@ function bundle_samples(
5761
model::DensityModel,
5862
s::Metropolis,
5963
N::Integer,
60-
ts::Vector{T};
64+
ts::Type{Any};
6165
param_names=missing,
6266
kwargs...
6367
) where {ModelType<:AbstractModel, T<:AbstractTransition}
@@ -70,13 +74,13 @@ function bundle_samples(
7074
s::Metropolis,
7175
N::Integer,
7276
ts::Vector{T},
73-
chain_type::Type{NamedTuple};
77+
chain_type::Type{Vector{NamedTuple}};
7478
param_names=missing,
7579
kwargs...
7680
) where {ModelType<:AbstractModel, T<:AbstractTransition}
7781
# Check if we received any parameter names.
7882
if ismissing(param_names)
79-
param_names = ["param_$i" for i in 1:length(s.init_params)]
83+
param_names = ["param_$i" for i in 1:length(keys(ts[1].params))]
8084
else
8185
# Deepcopy to be thread safe.
8286
param_names = deepcopy(param_names)
@@ -85,8 +89,8 @@ function bundle_samples(
8589
push!(param_names, "lp")
8690

8791
# Turn all the transitions into a vector-of-NamedTuple.
88-
keys = tuple(Symbol.(param_names)...)
89-
nts = [NamedTuple{keys}(tuple(t.params..., t.lp)) for t in ts]
92+
ks = tuple(Symbol.(param_names)...)
93+
nts = [NamedTuple{ks}(tuple(t.params..., t.lp)) for t in ts]
9094

9195
return nts
9296
end
@@ -97,8 +101,7 @@ function __init__()
97101
end
98102

99103
# Include inference methods.
104+
include("proposal.jl")
100105
include("mh-core.jl")
101-
include("rwmh.jl")
102-
include("staticmh.jl")
103106

104107
end # module AdvancedMH

src/mcmcchains-connect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import MCMCChains: Chains
1+
import .MCMCChains: Chains
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function bundle_samples(

src/mh-core.jl

Lines changed: 140 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,150 @@
11
"""
2-
MetropolisHastings{P<:ProposalStyle, T, D}
2+
MetropolisHastings{D}
33
4-
Fields:
4+
`MetropolisHastings` has one field, `proposal`.
5+
`proposal` is a `Proposal`, `NamedTuple` of `Proposal`, or `Array{Proposal}` in the shape of your data.
6+
For example, if you wanted the sampler to return a `NamedTuple` with shape
57
6-
- `init_params` is the vector form of the parameters needed for the likelihood function.
7-
- `proposal` is a function that dynamically constructs a conditional distribution.
8+
```julia
9+
x = (a = 1.0, b=3.8)
10+
```
811
9-
Example:
12+
The proposal would be
1013
1114
```julia
12-
MetropolisHastings([0.0, 0.0], x -> MvNormal(x, 1.0))
15+
proposal = (a=Proposal(Static(), Normal(0,1)), b=Proposal(Static(), Normal(0,1)))
1316
````
17+
18+
Other allowed proposal styles are
19+
20+
```
21+
p1 = Proposal(Static(), Normal(0,1))
22+
p2 = Proposal(Static(), [Normal(0,1), InverseGamma(2,3)])
23+
p3 = Proposal(Static(), (a=Normal(0,1), b=InverseGamma(2,3)))
24+
p4 = Proposal(Static(), (x=1.0) -> Normal(x, 1))
25+
```
26+
27+
The sampler is constructed using
28+
29+
```julia
30+
spl = MetropolisHastings(proposal)
31+
```
1432
"""
15-
mutable struct MetropolisHastings{P<:ProposalStyle, D, T} <: Metropolis
16-
proposal_type :: P
33+
mutable struct MetropolisHastings{D} <: Metropolis
1734
proposal :: D
18-
init_params :: T
35+
end
36+
37+
StaticMH(d) = MetropolisHastings(Proposal(Static(), d))
38+
RWMH(d) = MetropolisHastings(Proposal(RandomWalk(), d))
39+
40+
# Propose from a vector of proposals
41+
function propose(
42+
spl::MetropolisHastings{<:AbstractArray},
43+
model::DensityModel
44+
)
45+
proposal = map(i -> propose(spl.proposal[i], model), 1:length(spl.proposal))
46+
return Transition(model, proposal)
47+
end
48+
49+
function propose(
50+
spl::MetropolisHastings{<:AbstractArray},
51+
model::DensityModel,
52+
params_prev::Transition
53+
)
54+
proposal = map(i -> propose(spl.proposal[i], model, params_prev.params[i]), 1:length(spl.proposal))
55+
return Transition(model, proposal)
56+
end
57+
58+
# Make a proposal from one Proposal struct.
59+
function propose(
60+
spl::MetropolisHastings{<:Proposal},
61+
model::DensityModel
62+
)
63+
proposal = propose(spl.proposal, model)
64+
return Transition(model, proposal)
65+
end
66+
67+
function propose(
68+
spl::MetropolisHastings{<:Proposal},
69+
model::DensityModel,
70+
params_prev::Transition
71+
)
72+
proposal = propose(spl.proposal, model, params_prev.params)
73+
return Transition(model, proposal)
74+
end
75+
76+
# Make a proposal from a NamedTuple of Proposal.
77+
function propose(
78+
spl::MetropolisHastings{<:NamedTuple},
79+
model::DensityModel
80+
)
81+
proposal = _propose(spl.proposal, model)
82+
return Transition(model, proposal)
83+
end
84+
85+
@generated function _propose(
86+
proposals::NamedTuple{names},
87+
model::DensityModel
88+
) where {names}
89+
expr = Expr(:tuple)
90+
map(names) do f
91+
push!(expr.args, Expr(:(=), f, :(propose(proposals.$f, model)) ))
92+
end
93+
return expr
94+
end
95+
96+
function propose(
97+
spl::MetropolisHastings{<:NamedTuple},
98+
model::DensityModel,
99+
params_prev::Transition
100+
)
101+
proposal = _propose(spl.proposal, model, params_prev.params)
102+
return Transition(model, proposal)
103+
end
104+
105+
@generated function _propose(
106+
proposals::NamedTuple{names},
107+
model::DensityModel,
108+
params_prev::NamedTuple
109+
) where {names}
110+
expr = Expr(:tuple)
111+
map(names) do f
112+
push!(expr.args, Expr(:(=), f, :(propose(proposals.$f, model, params_prev.$f)) ))
113+
end
114+
return expr
115+
end
116+
117+
118+
# Evaluate the likelihood of t conditional on t_cond.
119+
function q(
120+
spl::MetropolisHastings{<:AbstractArray},
121+
t::Transition,
122+
t_cond::Transition
123+
)
124+
return sum(map(i -> q(spl.proposal[i], t.params[i], t_cond.params[i]), 1:length(spl.proposal)))
125+
end
126+
127+
function q(
128+
spl::MetropolisHastings{<:Proposal},
129+
t::Transition,
130+
t_cond::Transition
131+
)
132+
return q(spl.proposal, t.params, t_cond.params)
133+
end
134+
135+
function q(
136+
spl::MetropolisHastings{<:NamedTuple},
137+
t::Transition,
138+
t_cond::Transition
139+
)
140+
ks = keys(t.params)
141+
total = 0.0
142+
143+
for k in ks
144+
total += q(spl.proposal[k], t.params[k], t_cond.params[k])
145+
end
146+
147+
return total
19148
end
20149

21150
# Define the first step! function, which is called at the
@@ -28,7 +157,7 @@ function step!(
28157
N::Integer;
29158
kwargs...
30159
)
31-
return Transition(model, spl.init_params)
160+
return propose(spl, model)
32161
end
33162

34163
# Define the other step functions. Returns a Transition containing
@@ -50,7 +179,7 @@ function step!(
50179
q(spl, params_prev, params) - q(spl, params, params_prev)
51180

52181
# Decide whether to return the previous params or the new one.
53-
if log(rand(rng)) < min(α, 0.0)
182+
if log(rand(rng)) < min(0.0, α)
54183
return params
55184
else
56185
return params_prev

src/proposal.jl

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
struct Proposal{T<:ProposalStyle, P}
2+
type :: T
3+
proposal :: P
4+
end
5+
6+
# Random draws
7+
Base.rand(p::Proposal{<:ProposalStyle, <:Distribution}) = rand(p.proposal)
8+
function Base.rand(p::Proposal{<:ProposalStyle, <:AbstractArray})
9+
return map(rand, p.proposal)
10+
end
11+
12+
# Densities
13+
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:UnivariateDistribution}, v)
14+
return sum(logpdf(p.proposal, v))
15+
end
16+
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:MultivariateDistribution}, v)
17+
return sum(logpdf(p.proposal, v))
18+
end
19+
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:MatrixDistribution}, v)
20+
return sum(logpdf(p.proposal, v))
21+
end
22+
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:AbstractArray}, v)
23+
return sum(map(x -> logpdf(x[1], x[2]), zip(p.proposal, v)))
24+
end
25+
function Distributions.logpdf(p::Proposal{<:ProposalStyle, <:Function}, v)
26+
return logpdf(p.proposal(v), v)
27+
end
28+
29+
###############
30+
# Random Walk #
31+
###############
32+
33+
function propose(
34+
proposal::Proposal{RandomWalk, <:Distribution},
35+
model::DensityModel,
36+
t
37+
)
38+
return t + rand(proposal)
39+
end
40+
41+
function propose(
42+
proposal::Proposal{RandomWalk, <:AbstractArray},
43+
model::DensityModel,
44+
t
45+
)
46+
return t + rand(proposal)
47+
end
48+
49+
function q(
50+
proposal::Proposal{RandomWalk, <:Distribution},
51+
t,
52+
t_cond
53+
)
54+
return sum(logpdf(proposal, t - t_cond))
55+
end
56+
57+
function q(
58+
proposal::Proposal{RandomWalk, <:AbstractArray},
59+
t,
60+
t_cond
61+
)
62+
return sum(logpdf(proposal, t - t_cond))
63+
end
64+
65+
##########
66+
# Static #
67+
##########
68+
69+
propose(p::Proposal{RandomWalk}, m::DensityModel) = propose(Proposal(Static(), p.proposal), m)
70+
function propose(
71+
proposal::Proposal{Static, <:Distribution},
72+
model::DensityModel,
73+
t=nothing
74+
)
75+
return rand(proposal)
76+
end
77+
78+
function propose(
79+
p::Proposal{Static, <:AbstractArray},
80+
model::DensityModel,
81+
t=nothing
82+
)
83+
props = map(x -> rand(x), p.proposal)
84+
return props
85+
end
86+
87+
function q(
88+
proposal::Proposal{Static, <:Distribution},
89+
t,
90+
t_cond
91+
)
92+
return sum(logpdf(proposal, t))
93+
end
94+
95+
function q(
96+
proposal::Proposal{Static, <:AbstractArray},
97+
t,
98+
t_cond
99+
)
100+
return sum(logpdf(proposal, t))
101+
end
102+
103+
############
104+
# Function #
105+
############
106+
107+
function propose(
108+
proposal::Proposal{<:ProposalStyle, <:Function},
109+
model::DensityModel
110+
)
111+
p = proposal.proposal
112+
return rand(proposal.proposal())
113+
end
114+
115+
function propose(
116+
proposal::Proposal{<:ProposalStyle, <:Function},
117+
model::DensityModel,
118+
t
119+
)
120+
p = proposal.proposal
121+
return rand(proposal.proposal(t))
122+
end
123+
124+
function q(
125+
proposal::Proposal{<:ProposalStyle, <:Function},
126+
t,
127+
t_cond
128+
)
129+
p = proposal.proposal
130+
return sum(logpdf.(p(t_cond), t))
131+
end

0 commit comments

Comments
 (0)