Skip to content

Commit 0af8725

Browse files
committed
Simplify ESS
1 parent 06fec2d commit 0af8725

File tree

1 file changed

+19
-24
lines changed

1 file changed

+19
-24
lines changed

src/mcmc/ess.jl

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ struct ESS <: InferenceAlgorithm end
2424

2525
# always accept in the first step
2626
function DynamicPPL.initialstep(
27-
rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
27+
rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
2828
)
2929
for vn in keys(vi)
3030
dist = getdist(vi, vn)
@@ -35,7 +35,7 @@ function DynamicPPL.initialstep(
3535
end
3636

3737
function AbstractMCMC.step(
38-
rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
38+
rng::AbstractRNG, model::Model, ::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
3939
)
4040
# obtain previous sample
4141
f = vi[:]
@@ -47,12 +47,7 @@ function AbstractMCMC.step(
4747
# compute next state
4848
sample, state = AbstractMCMC.step(
4949
rng,
50-
EllipticalSliceSampling.ESSModel(
51-
ESSPrior(model, spl, vi),
52-
ESSLikelihood(
53-
DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, vi)
54-
),
55-
),
50+
EllipticalSliceSampling.ESSModel(ESSPrior(model, vi), ESSLikelihood(model, vi)),
5651
EllipticalSliceSampling.ESS(),
5752
oldstate,
5853
)
@@ -65,36 +60,28 @@ function AbstractMCMC.step(
6560
end
6661

6762
# Prior distribution of considered random variable
68-
struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
63+
struct ESSPrior{M<:Model,V<:AbstractVarInfo,T}
6964
model::M
70-
sampler::S
7165
varinfo::V
7266
μ::T
7367

74-
function ESSPrior{M,S,V}(
75-
model::M, sampler::S, varinfo::V
76-
) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
68+
function ESSPrior(model::Model, varinfo::AbstractVarInfo)
7769
vns = keys(varinfo)
7870
μ = mapreduce(vcat, vns) do vn
7971
dist = getdist(varinfo, vn)
8072
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
8173
error("[ESS] only supports Gaussian prior distributions")
8274
DynamicPPL.tovec(mean(dist))
8375
end
84-
return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
76+
return new{typeof(model),typeof(varinfo),typeof(μ)}(model, varinfo, μ)
8577
end
8678
end
8779

88-
function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
89-
return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo)
90-
end
91-
9280
# Ensure that the prior is a Gaussian distribution (checked in the constructor)
9381
EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
9482

9583
# Only define out-of-place sampling
9684
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
97-
sampler = p.sampler
9885
varinfo = p.varinfo
9986
# TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
10087
# TODO(DPPL0.37/penelopeysm): This can be replaced with `init!!(p.model,
@@ -105,16 +92,24 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
10592
for vn in vns
10693
set_flag!(varinfo, vn, "del")
10794
end
108-
p.model(rng, varinfo, sampler)
95+
p.model(rng, varinfo)
10996
return varinfo[:]
11097
end
11198

11299
# Mean of prior distribution
113100
Distributions.mean(p::ESSPrior) = p.μ
114101

115-
# Evaluate log-likelihood of proposals
116-
struct ESSLogLikelihood{M<:Model,V<:AbstractVarInfo,AD<:ADTypes.AbstractADType}
117-
ldf::DynamicPPL.LogDensityFunction{M,V,AD}
102+
# Evaluate log-likelihood of proposals. We need this struct because
103+
# EllipticalSliceSampling.jl expects a callable struct / a function as its
104+
# likelihood.
105+
struct ESSLikelihood{M<:Model,V<:AbstractVarInfo}
106+
ldf::DynamicPPL.LogDensityFunction{M,V}
107+
108+
# Force usage of `getloglikelihood` in inner constructor
109+
function ESSLogLikelihood(model::Model, varinfo::AbstractVarInfo)
110+
ldf = DynamicPPL.LogDensityFunction(model, DynamicPPL.getloglikelihood, varinfo)
111+
return new{typeof(model),typeof(varinfo)}(ldf)
112+
end
118113
end
119114

120-
(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f)
115+
(ℓ::ESSLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ.ldf, f)

0 commit comments

Comments
 (0)