Skip to content

Commit f59287a

Browse files
committed
Addressing comments
1 parent 815fcd5 commit f59287a

File tree

4 files changed

+23
-24
lines changed

4 files changed

+23
-24
lines changed

src/AdvancedMH.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@ using Requires
88
import Random
99

1010
# Exports
11-
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal, RandomWalkProposal
11+
export MetropolisHastings, DensityModel, RWMH, StaticMH, StaticProposal,
12+
RandomWalkProposal, Ensemble, StretchProposal
1213

1314
# Reexports
1415
export sample, MCMCThreads, MCMCDistributed

src/emcee.jl

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,16 @@ function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel)
4747
return [propose(rng, mh_spl, model) for _ in 1:spl.n_walkers]
4848
end
4949

50-
51-
5250
#
5351
# Every other proposal
5452
#
5553
function propose(rng::Random.AbstractRNG, spl::Ensemble, model::DensityModel, walkers::Vector{W}) where {W<:Transition}
56-
new_walkers = Vector{W}(undef, spl.n_walkers)
57-
interval = 1:spl.n_walkers
58-
54+
new_walkers = similar(walkers)
5955

60-
nwalkers = spl.n_walkers
61-
others = 1:(nwalkers - 1)
62-
for i in interval
56+
others = 1:(spl.n_walkers - 1)
57+
for i in 1:spl.n_walkers
6358
walker = walkers[i]
64-
idx = mod1(i + rand(rng, others), nwalkers)
59+
idx = mod1(i + rand(rng, others), spl.n_walkers)
6560
other_walker = walkers[idx]
6661
new_walkers[i] = move(rng, spl, model, walker, other_walker)
6762
end
@@ -95,7 +90,7 @@ function move(
9590
alphamult = (n - 1) * log(z)
9691

9792
# Make new parameters
98-
y = walker.params + z .* (other_walker.params - walker.params)
93+
y = @. walker.params + z * (other_walker.params - walker.params)
9994

10095
# Construct a new walker
10196
new_walker = Transition(model, y)

src/mcmcchains-connect.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,20 @@ function AbstractMCMC.bundle_samples(
3939
param_names=missing,
4040
kwargs...
4141
)
42-
# return ts
43-
vals = mapreduce(
44-
t -> map(i -> vcat(ts[t][i].params,
45-
ts[t][i].lp, t, i),
46-
1:length(ts[t])),
47-
vcat,
48-
1:length(ts))
49-
50-
vals = Array(reduce(hcat, vals)')
42+
# Preallocate return array
43+
# NOTE: requires constant dimensionality.
44+
n_params = length(ts[1][1].params)
45+
vals = Array{Float64, 3}(undef, N, n_params + 1, s.n_walkers) # add 1 parameter for lp
5146

52-
# return vals
47+
for n in 1:N
48+
for i in 1:s.n_walkers
49+
walker = ts[n][i]
50+
for j in 1:n_params
51+
vals[n, j, i] = walker.params[j]
52+
end
53+
vals[n, n_params + 1, i] = walker.lp
54+
end
55+
end
5356

5457
# Check if we received any parameter names.
5558
if ismissing(param_names)
@@ -60,8 +63,8 @@ function AbstractMCMC.bundle_samples(
6063
end
6164

6265
# Add the log density field to the parameter names.
63-
push!(param_names, "lp", "iteration", "walker")
66+
push!(param_names, "lp")
6467

6568
# Bundle everything up and return a Chains struct.
66-
return Chains(vals, param_names, (internals=["lp", "iteration", "walker"],))
69+
return Chains(vals, param_names, (internals=["lp"],))
6770
end

src/structarray-connect.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import .StructArrays: StructArray
44
function AbstractMCMC.bundle_samples(
55
rng::Random.AbstractRNG,
66
model::DensityModel,
7-
s::Metropolis,
7+
s::MHSampler,
88
N::Integer,
99
ts::Vector,
1010
chain_type::Type{StructArray};

0 commit comments

Comments
 (0)