Skip to content

Commit 10accc2

Browse files
committed
Use dispatch instead of if-statement
1 parent 3a9a758 commit 10accc2

File tree

1 file changed

+32
-23
lines changed

1 file changed

+32
-23
lines changed

src/AdvancedMH.jl

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,31 +59,40 @@ function AbstractMCMC.bundle_samples(
5959
param_names=missing,
6060
kwargs...
6161
) where {T, L}
62-
# If the element type of ts is NamedTuples, just use those names:
63-
if T <: NamedTuple
64-
# Extract NamedTuples
65-
nts = map(x -> merge(x.params, (lp=x.lp,)), ts)
66-
67-
# Return em'
68-
return nts
62+
# Check if we received any parameter names.
63+
if ismissing(param_names)
64+
param_names = ["param_$i" for i in 1:length(keys(ts[1].params))]
6965
else
70-
# Otherwise, default to heuristics to infer parameter names.
71-
# Check if we received any parameter names.
72-
if ismissing(param_names)
73-
param_names = ["param_$i" for i in 1:length(keys(ts[1].params))]
74-
else
75-
# Deepcopy to be thread safe.
76-
param_names = deepcopy(param_names)
77-
end
78-
79-
push!(param_names, "lp")
80-
81-
# Turn all the transitions into a vector-of-NamedTuple.
82-
ks = tuple(Symbol.(param_names)...)
83-
nts = [NamedTuple{ks}(tuple(t.params..., t.lp)) for t in ts]
84-
85-
return nts
66+
# Deepcopy to be thread safe.
67+
param_names = deepcopy(param_names)
8668
end
69+
70+
push!(param_names, "lp")
71+
72+
# Turn all the transitions into a vector-of-NamedTuple.
73+
ks = tuple(Symbol.(param_names)...)
74+
nts = [NamedTuple{ks}(tuple(t.params..., t.lp)) for t in ts]
75+
76+
return nts
77+
end
78+
79+
function AbstractMCMC.bundle_samples(
80+
ts::Vector{Transition{T, L}},
81+
model::DensityModel,
82+
sampler::MHSampler,
83+
state,
84+
chain_type::Type{Vector{NamedTuple}};
85+
param_names=missing,
86+
kwargs...
87+
) where {T<:NamedTuple, L}
88+
# If the element type of ts is NamedTuples, just use the names in the
89+
# struct.
90+
91+
# Extract NamedTuples
92+
nts = map(x -> merge(x.params, (lp=x.lp,)), ts)
93+
94+
# Return em'
95+
return nts
8796
end
8897

8998
function __init__()

0 commit comments

Comments
 (0)