Skip to content

Commit 570e780

Browse files
authored
Merge pull request #49 from TuringLang/csp/ntfix
Add better method to convert Trandition{NamedTuple} to MCMCChains
2 parents efebf2f + 482e9aa commit 570e780

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.5.6"
3+
version = "0.5.7"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/mcmcchains-connect.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import .MCMCChains: Chains
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
5-
ts::Vector{<:Transition},
5+
ts::Vector{<:Transition{<:AbstractArray}},
66
model::DensityModel,
77
sampler::MHSampler,
88
state,
@@ -28,6 +28,40 @@ function AbstractMCMC.bundle_samples(
2828
return Chains(vals, param_names, (internals = [:lp],))
2929
end
3030

31+
function AbstractMCMC.bundle_samples(
32+
ts::Vector{<:Transition{<:NamedTuple}},
33+
model::DensityModel,
34+
sampler::MHSampler,
35+
state,
36+
chain_type::Type{Chains};
37+
param_names=missing,
38+
kwargs...
39+
)
40+
# Convert to a Vector{NamedTuple} first
41+
nts = AbstractMCMC.bundle_samples(ts, model, sampler, state, Vector{NamedTuple}; param_names=param_names, kwargs...)
42+
43+
# Get all the keys
44+
all_keys = unique(mapreduce(collectkeys, vcat, nts))
45+
46+
# Preallocate array
47+
# vals = []
48+
49+
# Push linearized draws onto array
50+
trygetproperty(thing, key) = key in keys(thing) ? getproperty(thing, key) : missing
51+
vals = map(nt -> [trygetproperty(nt, k) for k in all_keys], nts)
52+
53+
# Check if we received any parameter names.
54+
if ismissing(param_names)
55+
param_names = all_keys
56+
else
57+
# Generate new array to be thread safe.
58+
param_names = Symbol.(param_names)
59+
end
60+
61+
# Bundle everything up and return a Chains struct.
62+
return Chains(vals, param_names, (internals = [:lp],))
63+
end
64+
3165
function AbstractMCMC.bundle_samples(
3266
ts::Vector{<:Vector{<:Transition}},
3367
model::DensityModel,

test/runtests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,20 @@ include("util.jl")
7373
end
7474
end
7575

76+
@testset "MCMCChains" begin
77+
spl1 = StaticMH([Normal(0,1), Normal(0, 1)])
78+
spl2 = MetropolisHastings((μ = StaticProposal(Normal(0,1)), σ = StaticProposal(Normal(0, 1))))
79+
80+
chain1 = sample(model, spl1, 10_000; param_names=["μ", "σ"], chain_type=Chains)
81+
chain2 = sample(model, spl2, 10_000; chain_type=Chains)
82+
83+
@test mean(chain1["μ"]) 0.0 atol=0.1
84+
@test mean(chain1["σ"]) 1.0 atol=0.1
85+
86+
@test mean(chain2["μ"]) 0.0 atol=0.1
87+
@test mean(chain2["σ"]) 1.0 atol=0.1
88+
end
89+
7690
@testset "Proposal styles" begin
7791
m1 = DensityModel(x -> logpdf(Normal(x,1), 1.0))
7892
m2 = DensityModel(x -> logpdf(Normal(x[1], x[2]), 1.0))

0 commit comments

Comments
 (0)