Skip to content

Commit 947357e

Browse files
authored
Merge pull request #33 from devmotion/abstractmcmc2
Update to AbstractMCMC 2
2 parents 10766d4 + 82cd0a9 commit 947357e

File tree

6 files changed

+78
-87
lines changed

6 files changed

+78
-87
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.5.1"
3+
version = "0.5.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -9,9 +9,9 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
99
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1010

1111
[compat]
12-
AbstractMCMC = "1"
12+
AbstractMCMC = "2"
1313
Distributions = "0.20, 0.21, 0.22, 0.23"
14-
Requires = "1.0"
14+
Requires = "1"
1515
julia = "1"
1616

1717
[extras]

src/AdvancedMH.jl

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ struct DensityModel{F<:Function} <: AbstractMCMC.AbstractModel
3535
logdensity :: F
3636
end
3737

38-
# Create a very basic Transition type, only stores the
38+
# Create a very basic Transition type, only stores the
3939
# parameter draws and the log probability of the draw.
4040
struct Transition{T<:Union{Vector, Real, NamedTuple}, L<:Real}
4141
params :: T
@@ -51,25 +51,11 @@ logdensity(model::DensityModel, t::Transition) = t.lp
5151

5252
# A basic chains constructor that works with the Transition struct we defined.
5353
function AbstractMCMC.bundle_samples(
54-
rng::Random.AbstractRNG,
55-
model::DensityModel,
56-
s::MHSampler,
57-
N::Integer,
58-
ts::Vector,
59-
chain_type::Type{Any};
60-
param_names=missing,
61-
kwargs...
62-
)
63-
return ts
64-
end
65-
66-
function AbstractMCMC.bundle_samples(
67-
rng::Random.AbstractRNG,
68-
model::DensityModel,
69-
s::MHSampler,
70-
N::Integer,
71-
ts::Vector,
72-
chain_type::Type{Vector{NamedTuple}};
54+
ts::Vector{<:Transition},
55+
model::DensityModel,
56+
sampler::MHSampler,
57+
state,
58+
chain_type::Type{Vector{NamedTuple}};
7359
param_names=missing,
7460
kwargs...
7561
)

src/emcee.jl

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,43 +3,44 @@ struct Ensemble{D} <: MHSampler
33
proposal::D
44
end
55

6-
# Define the first step! function, which is called at the
7-
# beginning of sampling. Return the initial parameter used
8-
# to define the sampler.
9-
function AbstractMCMC.step!(
6+
# Define the first sampling step.
7+
# Return a 2-tuple consisting of the initial sample and the initial state.
8+
# In this case they are identical.
9+
function AbstractMCMC.step(
1010
rng::Random.AbstractRNG,
1111
model::DensityModel,
12-
spl::Ensemble,
13-
N::Integer,
14-
::Nothing;
12+
spl::Ensemble;
1513
init_params = nothing,
1614
kwargs...,
1715
)
1816
if init_params === nothing
19-
return propose(rng, spl, model)
17+
transitions = propose(rng, spl, model)
2018
else
21-
return Transition(model, init_params)
19+
transitions = [Transition(model, x) for x in init_params]
2220
end
21+
22+
return transitions, transitions
2323
end
2424

25-
# Define the other step functions. Returns a Transition containing
26-
# either a new proposal (if accepted) or the previous proposal
27-
# (if not accepted).
28-
function AbstractMCMC.step!(
25+
# Define the other sampling steps.
26+
# Return a 2-tuple consisting of the next sample and the the next state.
27+
# In this case they are identical, and for each walker they are either a new proposal
28+
# (if accepted) or the previous proposal (if not accepted).
29+
function AbstractMCMC.step(
2930
rng::Random.AbstractRNG,
3031
model::DensityModel,
3132
spl::Ensemble,
32-
::Integer,
33-
params_prev;
33+
params_prev::Vector{<:Transition};
3434
kwargs...,
3535
)
3636
# Generate a new proposal. Accept/reject happens at proposal level.
37-
return propose(rng, spl, model, params_prev)
37+
transitions = propose(rng, spl, model, params_prev)
38+
return transitions, transitions
3839
end
3940

4041
#
4142
# Initial proposal
42-
#
43+
#
4344
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel)
4445
# Make the first proposal with a static draw from the prior.
4546
static_prop = StaticProposal(spl.proposal.proposal)
@@ -49,8 +50,13 @@ end
4950

5051
#
5152
# Every other proposal
52-
#
53-
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel, walkers::Vector{W}) where {W<:Transition}
53+
#
54+
function propose(
55+
rng::Random.AbstractRNG,
56+
spl::Ensemble,
57+
model::DensityModel,
58+
walkers::Vector{<:Transition},
59+
)
5460
new_walkers = similar(walkers)
5561

5662
others = 1:(spl.n_walkers - 1)
@@ -64,7 +70,6 @@ function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel, wa
6470
return new_walkers
6571
end
6672

67-
6873
#####################################
6974
# Basic stretch move implementation #
7075
#####################################

src/mcmcchains-connect.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@ import .MCMCChains: Chains
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
5-
rng::Random.AbstractRNG,
6-
model::DensityModel,
7-
s::MHSampler,
8-
N::Integer,
9-
ts,
10-
chain_type::Type{Chains};
5+
ts::Vector{<:Transition},
6+
model::DensityModel,
7+
sampler::MHSampler,
8+
state,
9+
chain_type::Type{Chains};
1110
param_names=missing,
1211
kwargs...
1312
)
@@ -16,7 +15,7 @@ function AbstractMCMC.bundle_samples(
1615

1716
# Check if we received any parameter names.
1817
if ismissing(param_names)
19-
param_names = [Symbol(:param_, i) for i in 1:length(s.init_params)]
18+
param_names = [Symbol(:param_, i) for i in 1:length(keys(ts[1].params))]
2019
else
2120
# Generate new array to be thread safe.
2221
param_names = Symbol.(param_names)
@@ -30,22 +29,23 @@ function AbstractMCMC.bundle_samples(
3029
end
3130

3231
function AbstractMCMC.bundle_samples(
33-
rng::Random.AbstractRNG,
34-
model::DensityModel,
35-
s::Ensemble,
36-
N::Integer,
37-
ts::Vector,
38-
chain_type::Type{Chains};
32+
ts::Vector{<:Vector{<:Transition}},
33+
model::DensityModel,
34+
sampler::Ensemble,
35+
state,
36+
chain_type::Type{Chains};
3937
param_names=missing,
4038
kwargs...
4139
)
4240
# Preallocate return array
4341
# NOTE: requires constant dimensionality.
4442
n_params = length(ts[1][1].params)
45-
vals = Array{Float64, 3}(undef, N, n_params + 1, s.n_walkers) # add 1 parameter for lp
43+
nsamples = length(ts)
44+
# add 1 parameter for lp
45+
vals = Array{Float64, 3}(undef, nsamples, n_params + 1, sampler.n_walkers)
4646

47-
for n in 1:N
48-
for i in 1:s.n_walkers
47+
for n in 1:nsamples
48+
for i in 1:sampler.n_walkers
4949
walker = ts[n][i]
5050
for j in 1:n_params
5151
vals[n, j, i] = walker.params[j]
@@ -56,7 +56,7 @@ function AbstractMCMC.bundle_samples(
5656

5757
# Check if we received any parameter names.
5858
if ismissing(param_names)
59-
param_names = [Symbol(:param_, i) for i in 1:length(ts[1][1].params)]
59+
param_names = [Symbol(:param_, i) for i in 1:length(keys(ts[1][1].params))]
6060
else
6161
# Generate new array to be thread safe.
6262
param_names = Symbol.(param_names)

src/mh-core.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -169,34 +169,34 @@ function q(
169169
end
170170
end
171171

172-
# Define the first step! function, which is called at the
173-
# beginning of sampling. Return the initial parameter used
174-
# to define the sampler.
175-
function AbstractMCMC.step!(
172+
# Define the first sampling step.
173+
# Return a 2-tuple consisting of the initial sample and the initial state.
174+
# In this case they are identical.
175+
function AbstractMCMC.step(
176176
rng::Random.AbstractRNG,
177177
model::DensityModel,
178-
spl::MetropolisHastings,
179-
N::Integer,
180-
::Nothing;
178+
spl::MetropolisHastings;
181179
init_params=nothing,
182180
kwargs...
183181
)
184182
if init_params === nothing
185-
return propose(rng, spl, model)
183+
transition = propose(rng, spl, model)
186184
else
187-
return Transition(model, init_params)
185+
transition = Transition(model, init_params)
188186
end
187+
188+
return transition, transition
189189
end
190190

191-
# Define the other step functions. Returns a Transition containing
192-
# either a new proposal (if accepted) or the previous proposal
193-
# (if not accepted).
194-
function AbstractMCMC.step!(
191+
# Define the other sampling steps.
192+
# Return a 2-tuple consisting of the next sample and the the next state.
193+
# In this case they are identical, and either a new proposal (if accepted)
194+
# or the previous proposal (if not accepted).
195+
function AbstractMCMC.step(
195196
rng::Random.AbstractRNG,
196197
model::DensityModel,
197198
spl::MetropolisHastings,
198-
::Integer,
199-
params_prev;
199+
params_prev::Transition;
200200
kwargs...
201201
)
202202
# Generate a new proposal.
@@ -208,8 +208,8 @@ function AbstractMCMC.step!(
208208

209209
# Decide whether to return the previous params or the new one.
210210
if -Random.randexp(rng) < logα
211-
return params
211+
return params, params
212212
else
213-
return params_prev
213+
return params_prev, params_prev
214214
end
215215
end

src/structarray-connect.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,18 @@ import .StructArrays: StructArray
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
5-
rng::Random.AbstractRNG,
6-
model::DensityModel,
7-
s::MHSampler,
8-
N::Integer,
9-
ts::Vector,
10-
chain_type::Type{StructArray};
11-
param_names=missing,
5+
ts,
6+
model::DensityModel,
7+
sampler::MHSampler,
8+
state,
9+
chain_type::Type{StructArray};
1210
kwargs...
1311
)
14-
samples = AbstractMCMC.bundle_samples(rng, model, s, N, ts, Vector{NamedTuple};
15-
param_names=param_names, kwargs...)
12+
samples = AbstractMCMC.bundle_samples(
13+
ts, model, sampler, state, Vector{NamedTuple};
14+
kwargs...
15+
)
1616
return StructArray(samples)
1717
end
1818

19-
AbstractMCMC.chainscat(c::StructArray, cs::StructArray...) = vcat(c, cs...)
19+
AbstractMCMC.chainscat(c::StructArray, cs::StructArray...) = vcat(c, cs...)

0 commit comments

Comments
 (0)