Skip to content

Commit b8c26e2

Browse files
committed
Add better method to convert Trandition{NamedTuple} to MCMCChains
1 parent efebf2f commit b8c26e2

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedMH"
22
uuid = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
3-
version = "0.5.6"
3+
version = "0.5.7"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/mcmcchains-connect.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import .MCMCChains: Chains
22

33
# A basic chains constructor that works with the Transition struct we defined.
44
function AbstractMCMC.bundle_samples(
5-
ts::Vector{<:Transition},
5+
ts::Vector{<:Transition{<:AbstractArray}},
66
model::DensityModel,
77
sampler::MHSampler,
88
state,
@@ -28,6 +28,40 @@ function AbstractMCMC.bundle_samples(
2828
return Chains(vals, param_names, (internals = [:lp],))
2929
end
3030

31+
function AbstractMCMC.bundle_samples(
32+
ts::Vector{<:Transition{<:NamedTuple}},
33+
model::DensityModel,
34+
sampler::MHSampler,
35+
state,
36+
chain_type::Type{Chains};
37+
param_names=missing,
38+
kwargs...
39+
)
40+
# Convert to a Vector{NamedTuple} first
41+
nts = AbstractMCMC.bundle_samples(ts, model, sampler, state, Vector{NamedTuple}; param_names=param_names, kwargs...)
42+
43+
# Get all the keys
44+
all_keys = unique(mapreduce(collectkeys, vcat, nts))
45+
46+
# Preallocate array
47+
# vals = []
48+
49+
# Push linearized draws onto array
50+
trygetproperty(thing, key) = key in keys(thing) ? getproperty(thing, key) : missing
51+
vals = map(nt -> [trygetproperty(nt, k) for k in all_keys], nts)
52+
53+
# Check if we received any parameter names.
54+
if ismissing(param_names)
55+
param_names = all_keys
56+
else
57+
# Generate new array to be thread safe.
58+
param_names = Symbol.(param_names)
59+
end
60+
61+
# Bundle everything up and return a Chains struct.
62+
return Chains(vals, param_names, (internals = [:lp],))
63+
end
64+
3165
function AbstractMCMC.bundle_samples(
3266
ts::Vector{<:Vector{<:Transition}},
3367
model::DensityModel,

0 commit comments

Comments
 (0)