@@ -24,7 +24,7 @@ struct ESS <: InferenceAlgorithm end
24
24
25
25
# always accept in the first step
26
26
function DynamicPPL. initialstep (
27
- rng:: AbstractRNG , model:: Model , spl :: Sampler{<:ESS} , vi:: AbstractVarInfo ; kwargs...
27
+ rng:: AbstractRNG , model:: Model , :: Sampler{<:ESS} , vi:: AbstractVarInfo ; kwargs...
28
28
)
29
29
for vn in keys (vi)
30
30
dist = getdist (vi, vn)
@@ -35,7 +35,7 @@ function DynamicPPL.initialstep(
35
35
end
36
36
37
37
function AbstractMCMC. step (
38
- rng:: AbstractRNG , model:: Model , spl :: Sampler{<:ESS} , vi:: AbstractVarInfo ; kwargs...
38
+ rng:: AbstractRNG , model:: Model , :: Sampler{<:ESS} , vi:: AbstractVarInfo ; kwargs...
39
39
)
40
40
# obtain previous sample
41
41
f = vi[:]
@@ -47,12 +47,7 @@ function AbstractMCMC.step(
47
47
# compute next state
48
48
sample, state = AbstractMCMC. step (
49
49
rng,
50
- EllipticalSliceSampling. ESSModel (
51
- ESSPrior (model, spl, vi),
52
- ESSLikelihood (
53
- DynamicPPL. LogDensityFunction (model, DynamicPPL. getloglikelihood, vi)
54
- ),
55
- ),
50
+ EllipticalSliceSampling. ESSModel (ESSPrior (model, vi), ESSLikelihood (model, vi)),
56
51
EllipticalSliceSampling. ESS (),
57
52
oldstate,
58
53
)
@@ -65,36 +60,28 @@ function AbstractMCMC.step(
65
60
end
66
61
67
62
# Prior distribution of considered random variable
68
- struct ESSPrior{M<: Model ,S <: Sampler{<:ESS} , V<: AbstractVarInfo ,T}
63
+ struct ESSPrior{M<: Model ,V<: AbstractVarInfo ,T}
69
64
model:: M
70
- sampler:: S
71
65
varinfo:: V
72
66
μ:: T
73
67
74
- function ESSPrior {M,S,V} (
75
- model:: M , sampler:: S , varinfo:: V
76
- ) where {M<: Model ,S<: Sampler{<:ESS} ,V<: AbstractVarInfo }
68
+ function ESSPrior (model:: Model , varinfo:: AbstractVarInfo )
77
69
vns = keys (varinfo)
78
70
μ = mapreduce (vcat, vns) do vn
79
71
dist = getdist (varinfo, vn)
80
72
EllipticalSliceSampling. isgaussian (typeof (dist)) ||
81
73
error (" [ESS] only supports Gaussian prior distributions" )
82
74
DynamicPPL. tovec (mean (dist))
83
75
end
84
- return new {M,S,V, typeof(μ)} (model, sampler , varinfo, μ)
76
+ return new {typeof(model),typeof(varinfo), typeof(μ)} (model, varinfo, μ)
85
77
end
86
78
end
87
79
88
- function ESSPrior (model:: Model , sampler:: Sampler{<:ESS} , varinfo:: AbstractVarInfo )
89
- return ESSPrior {typeof(model),typeof(sampler),typeof(varinfo)} (model, sampler, varinfo)
90
- end
91
-
92
80
# Ensure that the prior is a Gaussian distribution (checked in the constructor)
93
81
EllipticalSliceSampling. isgaussian (:: Type{<:ESSPrior} ) = true
94
82
95
83
# Only define out-of-place sampling
96
84
function Base. rand (rng:: Random.AbstractRNG , p:: ESSPrior )
97
- sampler = p. sampler
98
85
varinfo = p. varinfo
99
86
# TODO : Surely there's a better way of doing this now that we have `SamplingContext`?
100
87
# TODO (DPPL0.37/penelopeysm): This can be replaced with `init!!(p.model,
@@ -105,16 +92,24 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
105
92
for vn in vns
106
93
set_flag! (varinfo, vn, " del" )
107
94
end
108
- p. model (rng, varinfo, sampler )
95
+ p. model (rng, varinfo)
109
96
return varinfo[:]
110
97
end
111
98
112
99
# Mean of prior distribution
113
100
Distributions. mean (p:: ESSPrior ) = p. μ
114
101
115
- # Evaluate log-likelihood of proposals
116
- struct ESSLogLikelihood{M<: Model ,V<: AbstractVarInfo ,AD<: ADTypes.AbstractADType }
117
- ldf:: DynamicPPL.LogDensityFunction{M,V,AD}
102
+ # Evaluate log-likelihood of proposals. We need this struct because
103
+ # EllipticalSliceSampling.jl expects a callable struct / a function as its
104
+ # likelihood.
105
+ struct ESSLikelihood{M<: Model ,V<: AbstractVarInfo }
106
+ ldf:: DynamicPPL.LogDensityFunction{M,V}
107
+
108
+ # Force usage of `getloglikelihood` in inner constructor
109
+ function ESSLogLikelihood (model:: Model , varinfo:: AbstractVarInfo )
110
+ ldf = DynamicPPL. LogDensityFunction (model, DynamicPPL. getloglikelihood, varinfo)
111
+ return new {typeof(model),typeof(varinfo)} (ldf)
112
+ end
118
113
end
119
114
120
- (ℓ:: ESSLogLikelihood )(f:: AbstractVector ) = LogDensityProblems. logdensity (ℓ. ldf, f)
115
+ (ℓ:: ESSLikelihood )(f:: AbstractVector ) = LogDensityProblems. logdensity (ℓ. ldf, f)
0 commit comments