@@ -27,15 +27,20 @@ function DynamicPPL.initialstep(
27
27
rng:: AbstractRNG , model:: Model , spl:: Sampler{<:ESS} , vi:: AbstractVarInfo ; kwargs...
28
28
)
29
29
# Sanity check
30
- vns = _getvns (vi, spl)
31
- length (vns) == 1 ||
32
- error (" [ESS] does only support one variable ($(length (vns)) variables specified)" )
33
- for vn in only (vns)
30
+ # TODO (mhauru) What is the point of the first check? Why is it relevant that if there
31
+ # are multiple variables they are all under the same symbol?
32
+ vn_syms = syms (vi)
33
+ if length (vn_syms) != 1
34
+ msg = """
35
+ ESS only supports one variable symbol ($(length (vn_syms)) variables specified)\
36
+ """
37
+ error (msg)
38
+ end
39
+ for vn in keys (vi)
34
40
dist = getdist (vi, vn)
35
41
EllipticalSliceSampling. isgaussian (typeof (dist)) ||
36
- error (" [ ESS] only supports Gaussian prior distributions" )
42
+ error (" ESS only supports Gaussian prior distributions" )
37
43
end
38
-
39
44
return Transition (model, vi), vi
40
45
end
41
46
@@ -61,7 +66,7 @@ function AbstractMCMC.step(
61
66
)
62
67
63
68
# update sample and log-likelihood
64
- vi = setindex!! (vi, sample, spl )
69
+ vi = DynamicPPL . unflatten (vi, sample)
65
70
vi = setlogp!! (vi, state. loglikelihood)
66
71
67
72
return Transition (model, vi), vi
@@ -77,8 +82,8 @@ struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
77
82
function ESSPrior {M,S,V} (
78
83
model:: M , sampler:: S , varinfo:: V
79
84
) where {M<: Model ,S<: Sampler{<:ESS} ,V<: AbstractVarInfo }
80
- vns = _getvns (varinfo, sampler )
81
- μ = mapreduce (vcat, vns[ 1 ] ) do vn
85
+ vns = keys (varinfo)
86
+ μ = mapreduce (vcat, vns) do vn
82
87
dist = getdist (varinfo, vn)
83
88
EllipticalSliceSampling. isgaussian (typeof (dist)) ||
84
89
error (" [ESS] only supports Gaussian prior distributions" )
@@ -100,12 +105,12 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
100
105
sampler = p. sampler
101
106
varinfo = p. varinfo
102
107
# TODO : Surely there's a better way of doing this now that we have `SamplingContext`?
103
- vns = _getvns (varinfo, sampler )
104
- for vn in Iterators . flatten ( values ( vns))
108
+ vns = keys (varinfo)
109
+ for vn in vns
105
110
set_flag! (varinfo, vn, " del" )
106
111
end
107
112
p. model (rng, varinfo, sampler)
108
- return varinfo[sampler ]
113
+ return varinfo[: ]
109
114
end
110
115
111
116
# Mean of prior distribution
@@ -118,7 +123,7 @@ const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = Turing.L
118
123
119
124
function (ℓ:: ESSLogLikelihood )(f:: AbstractVector )
120
125
sampler = DynamicPPL. getsampler (ℓ)
121
- varinfo = setindex!! (ℓ. varinfo, f, sampler )
126
+ varinfo = DynamicPPL . unflatten (ℓ. varinfo, f)
122
127
varinfo = last (DynamicPPL. evaluate!! (ℓ. model, varinfo, sampler))
123
128
return getlogp (varinfo)
124
129
end
0 commit comments