Skip to content

Commit 97cf02b

Browse files
committed
More fixing of DPPL v0.35 stuff
1 parent 58cef90 commit 97cf02b

File tree

3 files changed

+20
-14
lines changed

3 files changed

+20
-14
lines changed

src/mcmc/Inference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using DynamicPPL:
88
# TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL. Either export it
99
# or use something else.
1010
all_varnames_grouped_by_symbol,
11+
syms,
1112
islinked,
1213
setindex!!,
1314
push!!,

src/mcmc/emcee.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ function AbstractMCMC.step(
6868
vis[1],
6969
map(vis) do vi
7070
vi = DynamicPPL.link!!(vi, model)
71-
AMH.Transition(vi[spl], getlogp(vi), false)
71+
AMH.Transition(vi[:], getlogp(vi), false)
7272
end,
7373
)
7474

src/mcmc/ess.jl

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,20 @@ function DynamicPPL.initialstep(
2727
rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
2828
)
2929
# Sanity check
30-
vns = _getvns(vi, spl)
31-
length(vns) == 1 ||
32-
error("[ESS] does only support one variable ($(length(vns)) variables specified)")
33-
for vn in only(vns)
30+
# TODO(mhauru) What is the point of the first check? Why is it relevant that if there
31+
# are multiple variables they are all under the same symbol?
32+
vn_syms = syms(vi)
33+
if length(vn_syms) != 1
34+
msg = """
35+
ESS only supports one variable symbol ($(length(vn_syms)) variables specified)\
36+
"""
37+
error(msg)
38+
end
39+
for vn in keys(vi)
3440
dist = getdist(vi, vn)
3541
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
36-
error("[ESS] only supports Gaussian prior distributions")
42+
error("ESS only supports Gaussian prior distributions")
3743
end
38-
3944
return Transition(model, vi), vi
4045
end
4146

@@ -61,7 +66,7 @@ function AbstractMCMC.step(
6166
)
6267

6368
# update sample and log-likelihood
64-
vi = setindex!!(vi, sample, spl)
69+
vi = DynamicPPL.unflatten(vi, sample)
6570
vi = setlogp!!(vi, state.loglikelihood)
6671

6772
return Transition(model, vi), vi
@@ -77,8 +82,8 @@ struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
7782
function ESSPrior{M,S,V}(
7883
model::M, sampler::S, varinfo::V
7984
) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
80-
vns = _getvns(varinfo, sampler)
81-
μ = mapreduce(vcat, vns[1]) do vn
85+
vns = keys(varinfo)
86+
μ = mapreduce(vcat, vns) do vn
8287
dist = getdist(varinfo, vn)
8388
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
8489
error("[ESS] only supports Gaussian prior distributions")
@@ -100,12 +105,12 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
100105
sampler = p.sampler
101106
varinfo = p.varinfo
102107
# TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
103-
vns = _getvns(varinfo, sampler)
104-
for vn in Iterators.flatten(values(vns))
108+
vns = keys(varinfo)
109+
for vn in vns
105110
set_flag!(varinfo, vn, "del")
106111
end
107112
p.model(rng, varinfo, sampler)
108-
return varinfo[sampler]
113+
return varinfo[:]
109114
end
110115

111116
# Mean of prior distribution
@@ -118,7 +123,7 @@ const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = Turing.L
118123

119124
function (ℓ::ESSLogLikelihood)(f::AbstractVector)
120125
sampler = DynamicPPL.getsampler(ℓ)
121-
varinfo = setindex!!(ℓ.varinfo, f, sampler)
126+
varinfo = DynamicPPL.unflatten(ℓ.varinfo, f)
122127
varinfo = last(DynamicPPL.evaluate!!(ℓ.model, varinfo, sampler))
123128
return getlogp(varinfo)
124129
end

0 commit comments

Comments
 (0)