@@ -2,7 +2,7 @@ import .MCMCChains: Chains
2
2
3
3
# A basic chains constructor that works with the Transition struct we defined.
4
4
function AbstractMCMC. bundle_samples (
5
- ts:: Vector{<:Transition} ,
5
+ ts:: Vector{<:Transition{<:AbstractArray} } ,
6
6
model:: DensityModel ,
7
7
sampler:: MHSampler ,
8
8
state,
@@ -28,6 +28,40 @@ function AbstractMCMC.bundle_samples(
28
28
return Chains (vals, param_names, (internals = [:lp ],))
29
29
end
30
30
31
+ function AbstractMCMC. bundle_samples (
32
+ ts:: Vector{<:Transition{<:NamedTuple}} ,
33
+ model:: DensityModel ,
34
+ sampler:: MHSampler ,
35
+ state,
36
+ chain_type:: Type{Chains} ;
37
+ param_names= missing ,
38
+ kwargs...
39
+ )
40
+ # Convert to a Vector{NamedTuple} first
41
+ nts = AbstractMCMC. bundle_samples (ts, model, sampler, state, Vector{NamedTuple}; param_names= param_names, kwargs... )
42
+
43
+ # Get all the keys
44
+ all_keys = unique (mapreduce (collect∘ keys, vcat, nts))
45
+
46
+ # Preallocate array
47
+ # vals = []
48
+
49
+ # Push linearized draws onto array
50
+ trygetproperty (thing, key) = key in keys (thing) ? getproperty (thing, key) : missing
51
+ vals = map (nt -> [trygetproperty (nt, k) for k in all_keys], nts)
52
+
53
+ # Check if we received any parameter names.
54
+ if ismissing (param_names)
55
+ param_names = all_keys
56
+ else
57
+ # Generate new array to be thread safe.
58
+ param_names = Symbol .(param_names)
59
+ end
60
+
61
+ # Bundle everything up and return a Chains struct.
62
+ return Chains (vals, param_names, (internals = [:lp ],))
63
+ end
64
+
31
65
function AbstractMCMC. bundle_samples (
32
66
ts:: Vector{<:Vector{<:Transition}} ,
33
67
model:: DensityModel ,
0 commit comments