Skip to content

Commit 62b638e

Browse files
authored
Merge pull request #21 from TuringLang/csp/emcee
Add affine-invariant ensemble sampler
2 parents 842af40 + 98d3144 commit 62b638e

File tree

5 files changed

+273
-13
lines changed

5 files changed

+273
-13
lines changed

src/AdvancedMH.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ using Requires
88
import Random
99

1010
# Exports
11-
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal, RandomWalkProposal
11+
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal,
12+
RandomWalkProposal, Ensemble, StretchProposal
1213

1314
# Reexports
1415
export sample, MCMCThreads, MCMCDistributed
1516

16-
# Abstract type for MH-style samplers.
17-
abstract type Metropolis <: AbstractMCMC.AbstractSampler end
17+
# Abstract type for MH-style samplers. Needs better name?
18+
abstract type MHSampler <: AbstractMCMC.AbstractSampler end
1819

1920
# Define a model type. Stores the log density function and the data to
2021
# evaluate the log density on.
@@ -52,7 +53,7 @@ logdensity(model::DensityModel, t::Transition) = t.lp
5253
function AbstractMCMC.bundle_samples(
5354
rng::Random.AbstractRNG,
5455
model::DensityModel,
55-
s::Metropolis,
56+
s::MHSampler,
5657
N::Integer,
5758
ts::Vector,
5859
chain_type::Type{Any};
@@ -65,7 +66,7 @@ end
6566
function AbstractMCMC.bundle_samples(
6667
rng::Random.AbstractRNG,
6768
model::DensityModel,
68-
s::Metropolis,
69+
s::MHSampler,
6970
N::Integer,
7071
ts::Vector,
7172
chain_type::Type{Vector{NamedTuple}};
@@ -97,5 +98,6 @@ end
9798
# Include inference methods.
9899
include("proposal.jl")
99100
include("mh-core.jl")
101+
include("emcee.jl")
100102

101103
end # module AdvancedMH

src/emcee.jl

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
struct Ensemble{D} <: MHSampler
2+
n_walkers::Int
3+
proposal::D
4+
end
5+
6+
# Define the first step! function, which is called at the
7+
# beginning of sampling. Return the initial parameter used
8+
# to define the sampler.
9+
function AbstractMCMC.step!(
10+
rng::Random.AbstractRNG,
11+
model::DensityModel,
12+
spl::Ensemble,
13+
N::Integer,
14+
::Nothing;
15+
init_params = nothing,
16+
kwargs...,
17+
)
18+
if init_params === nothing
19+
return propose(rng, spl, model)
20+
else
21+
return Transition(model, init_params)
22+
end
23+
end
24+
25+
# Define the other step functions. Returns a Transition containing
26+
# either a new proposal (if accepted) or the previous proposal
27+
# (if not accepted).
28+
function AbstractMCMC.step!(
29+
rng::Random.AbstractRNG,
30+
model::DensityModel,
31+
spl::Ensemble,
32+
::Integer,
33+
params_prev;
34+
kwargs...,
35+
)
36+
# Generate a new proposal. Accept/reject happens at proposal level.
37+
return propose(rng, spl, model, params_prev)
38+
end
39+
40+
#
41+
# Initial proposal
42+
#
43+
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel)
44+
# Make the first proposal with a static draw from the prior.
45+
static_prop = StaticProposal(spl.proposal.proposal)
46+
mh_spl = MetropolisHastings(static_prop)
47+
return [propose(rng, mh_spl, model) for _ in 1:spl.n_walkers]
48+
end
49+
50+
#
51+
# Every other proposal
52+
#
53+
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel, walkers::Vector{W}) where {W<:Transition}
54+
new_walkers = similar(walkers)
55+
56+
others = 1:(spl.n_walkers - 1)
57+
for i in 1:spl.n_walkers
58+
walker = walkers[i]
59+
idx = mod1(i + rand(rng, others), spl.n_walkers)
60+
other_walker = walkers[idx]
61+
new_walkers[i] = move(rng, spl, model, walker, other_walker)
62+
end
63+
64+
return new_walkers
65+
end
66+
67+
68+
#####################################
69+
# Basic stretch move implementation #
70+
#####################################
71+
struct StretchProposal{P, F<:AbstractFloat} <: Proposal{P}
72+
proposal :: P
73+
stretch_length::F
74+
end
75+
76+
StretchProposal(p) = StretchProposal(p, 2.0)
77+
78+
function move(
79+
rng::Random.AbstractRNG,
80+
spl::Ensemble{<:StretchProposal},
81+
model::DensityModel,
82+
walker::Transition,
83+
other_walker::Transition,
84+
)
85+
# Calculate intermediate values
86+
proposal = spl.proposal
87+
n = length(walker.params)
88+
a = proposal.stretch_length
89+
z = ((a - 1) * rand(rng) + 1)^2 / a
90+
alphamult = (n - 1) * log(z)
91+
92+
# Make new parameters
93+
y = @. walker.params + z * (other_walker.params - walker.params)
94+
95+
# Construct a new walker
96+
new_walker = Transition(model, y)
97+
98+
# Calculate accept/reject value.
99+
alpha = alphamult + new_walker.lp - walker.lp
100+
101+
if -Random.randexp(rng) <= alpha
102+
return new_walker
103+
else
104+
return walker
105+
end
106+
end
107+
108+
#########################
109+
# Elliptical slice step #
110+
# #########################
111+
112+
# struct EllipticalSlice{E} <: ProposalStyle
113+
# ellipse::E
114+
# end
115+
116+
# function move(
117+
# # spl::Ensemble,
118+
# spl::Ensemble{Proposal{T,P}},
119+
# model::DensityModel,
120+
# walker::Transition,
121+
# other_walker::Transition,
122+
# ) where {T<:EllipticalSlice,P}
123+
# # Calculate intermediate values
124+
# proposal = spl.proposal
125+
# n = length(walker.params)
126+
# nu = rand(proposal.type.ellipse)
127+
128+
# u = rand()
129+
# y = walker.lp - Random.randexp()
130+
131+
# theta = 2 * π * rand()
132+
133+
# theta_min = theta - 2.0*π
134+
# theta_max = theta
135+
136+
# f = walker.params
137+
# while true
138+
# stheta, ctheta = sincos(theta)
139+
140+
# f_prime = f .* ctheta + nu .* stheta
141+
142+
# new_walker = Transition(model, f_prime)
143+
144+
# if new_walker.lp > y
145+
# return new_walker
146+
# else
147+
# if theta < 0
148+
# theta_min = theta
149+
# else
150+
# theta_max = theta
151+
# end
152+
153+
# theta = theta_min + (theta_max - theta_min) * rand()
154+
# end
155+
# end
156+
# end
157+
158+
#####################
159+
# Slice and stretch #
160+
#####################
161+
# struct EllipticalSliceStretch{E, S<:Stretch} <: ProposalStyle
162+
# ellipse::E
163+
# stretch::S
164+
# end
165+
166+
# EllipticalSliceStretch(e) = EllipticalSliceStretch(e, Stretch(2.0))
167+
168+
# function move(
169+
# # spl::Ensemble,
170+
# spl::Ensemble{Proposal{T,P}},
171+
# model::DensityModel,
172+
# walker::Transition,
173+
# other_walker::Transition,
174+
# ) where {T<:EllipticalSliceStretch,P}
175+
# # Calculate intermediate values
176+
# proposal = spl.proposal
177+
# n = length(walker.params)
178+
# nu = rand(proposal.type.ellipse)
179+
180+
# # Calculate stretch step first
181+
# subspl = Ensemble(spl.n_walkers, Proposal(proposal.type.stretch, proposal.proposal))
182+
# walker = move(subspl, model, walker, other_walker)
183+
184+
# u = rand()
185+
# y = walker.lp - Random.randexp()
186+
187+
# theta = 2 * π * rand()
188+
189+
# theta_min = theta - 2.0*π
190+
# theta_max = theta
191+
192+
# f = walker.params
193+
194+
# i = 0
195+
# while true
196+
# i += 1
197+
198+
# stheta, ctheta = sincos(theta)
199+
200+
# f_prime = f .* ctheta + nu .* stheta
201+
202+
# new_walker = Transition(model, f_prime)
203+
204+
# # @info "Slice step" i f f_prime y new_walker.lp theta theta_max theta_min
205+
206+
# if new_walker.lp > y
207+
# return new_walker
208+
# else
209+
# if theta < 0
210+
# theta_min = theta
211+
# else
212+
# theta_max = theta
213+
# end
214+
215+
# theta = theta_min + (theta_max - theta_min) * rand()
216+
# end
217+
# end
218+
# end

src/mcmcchains-connect.jl

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ import .MCMCChains: Chains
44
function AbstractMCMC.bundle_samples(
55
rng::Random.AbstractRNG,
66
model::DensityModel,
7-
s::Metropolis,
7+
s::MHSampler,
88
N::Integer,
9-
ts::Vector,
9+
ts,
1010
chain_type::Type{Chains};
1111
param_names=missing,
1212
kwargs...
@@ -16,7 +16,47 @@ function AbstractMCMC.bundle_samples(
1616

1717
# Check if we received any parameter names.
1818
if ismissing(param_names)
19-
param_names = ["Parameter $i" for i in 1:length(s.init_params)]
19+
param_names = ["param_$i" for i in 1:length(s.init_params)]
20+
else
21+
# Deepcopy to be thread safe.
22+
param_names = deepcopy(param_names)
23+
end
24+
25+
# Add the log density field to the parameter names.
26+
push!(param_names, "lp")
27+
28+
# Bundle everything up and return a Chains struct.
29+
return Chains(vals, param_names, (internals=["lp"],))
30+
end
31+
32+
function AbstractMCMC.bundle_samples(
33+
rng::Random.AbstractRNG,
34+
model::DensityModel,
35+
s::Ensemble,
36+
N::Integer,
37+
ts::Vector,
38+
chain_type::Type{Chains};
39+
param_names=missing,
40+
kwargs...
41+
)
42+
# Preallocate return array
43+
# NOTE: requires constant dimensionality.
44+
n_params = length(ts[1][1].params)
45+
vals = Array{Float64, 3}(undef, N, n_params + 1, s.n_walkers) # add 1 parameter for lp
46+
47+
for n in 1:N
48+
for i in 1:s.n_walkers
49+
walker = ts[n][i]
50+
for j in 1:n_params
51+
vals[n, j, i] = walker.params[j]
52+
end
53+
vals[n, n_params + 1, i] = walker.lp
54+
end
55+
end
56+
57+
# Check if we received any parameter names.
58+
if ismissing(param_names)
59+
param_names = ["param_$i" for i in 1:length(ts[1][1].params)]
2060
else
2161
# Deepcopy to be thread safe.
2262
param_names = deepcopy(param_names)

src/mh-core.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ used if `chain_type=Chains`.
4141
types are `chain_type=Chains` if `MCMCChains` is imported, or
4242
`chain_type=StructArray` if `StructArrays` is imported.
4343
"""
44-
struct MetropolisHastings{D} <: Metropolis
44+
struct MetropolisHastings{D} <: MHSampler
4545
proposal::D
4646
end
4747

@@ -196,14 +196,14 @@ function AbstractMCMC.step!(
196196
model::DensityModel,
197197
spl::MetropolisHastings,
198198
::Integer,
199-
params_prev::Transition;
199+
params_prev;
200200
kwargs...
201201
)
202202
# Generate a new proposal.
203203
params = propose(rng, spl, model, params_prev)
204204

205205
# Calculate the log acceptance probability.
206-
logα = logdensity(model, params) - logdensity(model, params_prev) +
206+
logα = logdensity(model, params) - logdensity(model, params_prev) +
207207
q(spl, params_prev, params) - q(spl, params, params_prev)
208208

209209
# Decide whether to return the previous params or the new one.
@@ -212,4 +212,4 @@ function AbstractMCMC.step!(
212212
else
213213
return params_prev
214214
end
215-
end
215+
end

src/structarray-connect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import .StructArrays: StructArray
44
function AbstractMCMC.bundle_samples(
55
rng::Random.AbstractRNG,
66
model::DensityModel,
7-
s::Metropolis,
7+
s::MHSampler,
88
N::Integer,
99
ts::Vector,
1010
chain_type::Type{StructArray};

0 commit comments

Comments
 (0)