Skip to content

Commit f9ed562

Browse files
committed
Make Gibbs constructor more flexible
1 parent 519ff02 commit f9ed562

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

src/mcmc/gibbs.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,21 +320,30 @@ struct Gibbs{V,A} <: InferenceAlgorithm
320320
end
321321
end
322322

323+
to_varname(vn::VarName) = vn
324+
to_varname(s::Symbol) = VarName{s}()
325+
# Any other value is assumed to be an iterable.
326+
to_varname(t) = map(to_varname, collect(t))
327+
323328
# NamedTuple
324329
Gibbs(; algs...) = Gibbs(NamedTuple(algs))
325330
function Gibbs(algs::NamedTuple)
326331
return Gibbs(
327-
map(s -> VarName{s}(), keys(algs)),
328-
map(wrap_algorithm_maybe drop_space, values(algs)),
332+
map(to_varname, keys(algs)), map(wrap_algorithm_maybe drop_space, values(algs))
329333
)
330334
end
331335

332336
# AbstractDict
333337
function Gibbs(algs::AbstractDict)
334-
return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe drop_space, values(algs)))
338+
return Gibbs(
339+
map(to_varname, collect(keys(algs))),
340+
map(wrap_algorithm_maybe drop_space, values(algs)),
341+
)
335342
end
336343
function Gibbs(algs::Pair...)
337-
return Gibbs(map(first, algs), map(wrap_algorithm_maybe drop_space, map(last, algs)))
344+
return Gibbs(
345+
map(to_varname first, algs), map(wrap_algorithm_maybe drop_space last, algs)
346+
)
338347
end
339348

340349
# The below two constructors only provide backwards compatibility with the constructor of
@@ -383,6 +392,7 @@ end
383392
_maybevec(x) = vec(x) # assume it's iterable
384393
_maybevec(x::Tuple) = [x...]
385394
_maybevec(x::VarName) = [x]
395+
_maybevec(x::Symbol) = [x]
386396

387397
varinfo(state::GibbsState) = state.vi
388398

test/mcmc/gibbs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,15 +284,15 @@ end
284284
N = 10
285285
# Two variables being sampled by one sampler.
286286
s1 = Gibbs((@varname(s), @varname(m)) => HMC(0.1, 5; adtype=adbackend))
287-
s2 = Gibbs((@varname(s), @varname(m)) => PG(10))
287+
s2 = Gibbs((@varname(s), :m) => PG(10))
288288
# One variable per sampler, using the keyword arg interface.
289289
s3 = Gibbs((; s=PG(3), m=HMC(0.4, 8; adtype=adbackend)))
290290
# As above but using a Dict of VarNames.
291291
s4 = Gibbs(Dict(@varname(s) => PG(3), @varname(m) => HMC(0.4, 8; adtype=adbackend)))
292-
# As above but different samplers.
292+
# As above but different samplers and using kwargs.
293293
s5 = Gibbs(; s=CSMC(3), m=HMC(0.4, 8; adtype=adbackend))
294294
s6 = Gibbs(; s=HMC(0.1, 5; adtype=adbackend), m=ESS())
295-
s7 = Gibbs((@varname(s), @varname(m)) => PG(10))
295+
s7 = Gibbs(Dict((:s, @varname(m)) => PG(10)))
296296
# Multiple instnaces of the same sampler. This implements running, in this case,
297297
# 3 steps of HMC on m and 2 steps of PG on m in every iteration of Gibbs.
298298
s8 = begin

0 commit comments

Comments
 (0)