Skip to content

Commit 28a7c22

Browse files
authored
Update EllipticalSliceSampling (#1492)
1 parent 6d1562a commit 28a7c22

File tree

2 files changed

+58
-43
lines changed

2 files changed

+58
-43
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.15.4"
3+
version = "0.15.5"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -44,7 +44,7 @@ Distributions = "0.23.3, 0.24"
4444
DistributionsAD = "0.6"
4545
DocStringExtensions = "0.8"
4646
DynamicPPL = "0.10.2"
47-
EllipticalSliceSampling = "0.3"
47+
EllipticalSliceSampling = "0.4"
4848
ForwardDiff = "0.10.3"
4949
Libtask = "0.4, 0.5"
5050
LogDensityProblems = "^0.9, 0.10"

src/inference/ess.jl

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -64,67 +64,82 @@ function AbstractMCMC.step(
6464
end
6565

6666
# define previous sampler state
67-
oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi))
67+
# (do not use cache to avoid in-place sampling from prior)
68+
oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)
6869

6970
# compute next state
70-
_, state = AbstractMCMC.step(rng, ESSModel(model, spl, vi),
71-
EllipticalSliceSampling.ESS(), oldstate)
71+
sample, state = AbstractMCMC.step(
72+
rng,
73+
EllipticalSliceSampling.ESSModel(
74+
ESSPrior(model, spl, vi), ESSLogLikelihood(model, spl, vi),
75+
),
76+
EllipticalSliceSampling.ESS(),
77+
oldstate,
78+
)
7279

7380
# update sample and log-likelihood
74-
vi[spl] = state.sample
81+
vi[spl] = sample
7582
setlogp!(vi, state.loglikelihood)
7683

7784
return Transition(vi), vi
7885
end
7986

80-
struct ESSModel{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T} <: AbstractMCMC.AbstractModel
87+
# Prior distribution of considered random variable
88+
struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
8189
model::M
82-
spl::S
83-
vi::V
90+
sampler::S
91+
varinfo::V
8492
μ::T
85-
end
86-
87-
function ESSModel(model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo)
88-
vns = _getvns(vi, spl)
89-
μ = mapreduce(vcat, vns[1]) do vn
90-
dist = getdist(vi, vn)
91-
vectorize(dist, mean(dist))
93+
94+
function ESSPrior{M,S,V}(model::M, sampler::S, varinfo::V) where {
95+
M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo
96+
}
97+
vns = _getvns(varinfo, sampler)
98+
μ = mapreduce(vcat, vns[1]) do vn
99+
dist = getdist(varinfo, vn)
100+
EllipticalSliceSampling.isgaussian(typeof(dist)) ||
101+
error("[ESS] only supports Gaussian prior distributions")
102+
vectorize(dist, mean(dist))
103+
end
104+
return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
92105
end
93-
94-
ESSModel(model, spl, vi, μ)
95106
end
96107

97-
# sample from the prior
98-
function EllipticalSliceSampling.sample_prior(rng::Random.AbstractRNG, model::ESSModel)
99-
spl = model.spl
100-
vi = model.vi
101-
vns = _getvns(vi, spl)
102-
set_flag!(vi, vns[1][1], "del")
103-
model.model(rng, vi, spl)
104-
return vi[spl]
108+
function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
109+
return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(
110+
model, sampler, varinfo,
111+
)
105112
end
106113

107-
# compute proposal and apply correction for distributions with nonzero mean
108-
function EllipticalSliceSampling.proposal(model::ESSModel, f, ν, θ)
109-
sinθ, cosθ = sincos(θ)
110-
a = 1 - (sinθ + cosθ)
111-
return @. cosθ * f + sinθ * ν + a * model.μ
114+
# Ensure that the prior is a Gaussian distribution (checked in the constructor)
115+
EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
116+
117+
# Only define out-of-place sampling
118+
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
119+
sampler = p.sampler
120+
varinfo = p.varinfo
121+
vns = _getvns(varinfo, sampler)
122+
set_flag!(varinfo, vns[1][1], "del")
123+
p.model(rng, varinfo, sampler)
124+
return varinfo[sampler]
112125
end
113126

114-
function EllipticalSliceSampling.proposal!(out, model::ESSModel, f, ν, θ)
115-
sinθ, cosθ = sincos(θ)
116-
a = 1 - (sinθ + cosθ)
117-
@. out = cosθ * f + sinθ * ν + a * model.μ
118-
return out
127+
# Mean of prior distribution
128+
Distributions.mean(p::ESSPrior) = p.μ
129+
130+
# Evaluate log-likelihood of proposals
131+
struct ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
132+
model::M
133+
sampler::S
134+
varinfo::V
119135
end
120136

121-
# evaluate log-likelihood
122-
function Distributions.loglikelihood(model::ESSModel, f)
123-
spl = model.spl
124-
vi = model.vi
125-
vi[spl] = f
126-
model.model(vi, spl)
127-
getlogp(vi)
137+
function (ℓ::ESSLogLikelihood)(f)
138+
sampler =.sampler
139+
varinfo =.varinfo
140+
varinfo[sampler] = f
141+
.model(varinfo, sampler)
142+
return getlogp(varinfo)
128143
end
129144

130145
function DynamicPPL.tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi)

0 commit comments

Comments
 (0)