@@ -64,67 +64,82 @@ function AbstractMCMC.step(
64
64
end
65
65
66
66
# define previous sampler state
67
- oldstate = EllipticalSliceSampling. ESSState (f, getlogp (vi))
67
+ # (do not use cache to avoid in-place sampling from prior)
68
+ oldstate = EllipticalSliceSampling. ESSState (f, getlogp (vi), nothing )
68
69
69
70
# compute next state
70
- _, state = AbstractMCMC. step (rng, ESSModel (model, spl, vi),
71
- EllipticalSliceSampling. ESS (), oldstate)
71
+ sample, state = AbstractMCMC. step (
72
+ rng,
73
+ EllipticalSliceSampling. ESSModel (
74
+ ESSPrior (model, spl, vi), ESSLogLikelihood (model, spl, vi),
75
+ ),
76
+ EllipticalSliceSampling. ESS (),
77
+ oldstate,
78
+ )
72
79
73
80
# update sample and log-likelihood
74
- vi[spl] = state . sample
81
+ vi[spl] = sample
75
82
setlogp! (vi, state. loglikelihood)
76
83
77
84
return Transition (vi), vi
78
85
end
79
86
80
- struct ESSModel{M<: Model ,S<: Sampler{<:ESS} ,V<: AbstractVarInfo ,T} <: AbstractMCMC.AbstractModel
87
+ # Prior distribution of considered random variable
88
+ struct ESSPrior{M<: Model ,S<: Sampler{<:ESS} ,V<: AbstractVarInfo ,T}
81
89
model:: M
82
- spl :: S
83
- vi :: V
90
+ sampler :: S
91
+ varinfo :: V
84
92
μ:: T
85
- end
86
-
87
- function ESSModel (model:: Model , spl:: Sampler{<:ESS} , vi:: AbstractVarInfo )
88
- vns = _getvns (vi, spl)
89
- μ = mapreduce (vcat, vns[1 ]) do vn
90
- dist = getdist (vi, vn)
91
- vectorize (dist, mean (dist))
93
+
94
+ function ESSPrior {M,S,V} (model:: M , sampler:: S , varinfo:: V ) where {
95
+ M<: Model ,S<: Sampler{<:ESS} ,V<: AbstractVarInfo
96
+ }
97
+ vns = _getvns (varinfo, sampler)
98
+ μ = mapreduce (vcat, vns[1 ]) do vn
99
+ dist = getdist (varinfo, vn)
100
+ EllipticalSliceSampling. isgaussian (typeof (dist)) ||
101
+ error (" [ESS] only supports Gaussian prior distributions" )
102
+ vectorize (dist, mean (dist))
103
+ end
104
+ return new {M,S,V,typeof(μ)} (model, sampler, varinfo, μ)
92
105
end
93
-
94
- ESSModel (model, spl, vi, μ)
95
106
end
96
107
97
- # sample from the prior
98
- function EllipticalSliceSampling. sample_prior (rng:: Random.AbstractRNG , model:: ESSModel )
99
- spl = model. spl
100
- vi = model. vi
101
- vns = _getvns (vi, spl)
102
- set_flag! (vi, vns[1 ][1 ], " del" )
103
- model. model (rng, vi, spl)
104
- return vi[spl]
108
+ function ESSPrior (model:: Model , sampler:: Sampler{<:ESS} , varinfo:: AbstractVarInfo )
109
+ return ESSPrior {typeof(model),typeof(sampler),typeof(varinfo)} (
110
+ model, sampler, varinfo,
111
+ )
105
112
end
106
113
107
- # compute proposal and apply correction for distributions with nonzero mean
108
- function EllipticalSliceSampling. proposal (model:: ESSModel , f, ν, θ)
109
- sinθ, cosθ = sincos (θ)
110
- a = 1 - (sinθ + cosθ)
111
- return @. cosθ * f + sinθ * ν + a * model. μ
114
+ # Ensure that the prior is a Gaussian distribution (checked in the constructor)
115
+ EllipticalSliceSampling. isgaussian (:: Type{<:ESSPrior} ) = true
116
+
117
+ # Only define out-of-place sampling
118
+ function Base. rand (rng:: Random.AbstractRNG , p:: ESSPrior )
119
+ sampler = p. sampler
120
+ varinfo = p. varinfo
121
+ vns = _getvns (varinfo, sampler)
122
+ set_flag! (varinfo, vns[1 ][1 ], " del" )
123
+ p. model (rng, varinfo, sampler)
124
+ return varinfo[sampler]
112
125
end
113
126
114
- function EllipticalSliceSampling. proposal! (out, model:: ESSModel , f, ν, θ)
115
- sinθ, cosθ = sincos (θ)
116
- a = 1 - (sinθ + cosθ)
117
- @. out = cosθ * f + sinθ * ν + a * model. μ
118
- return out
127
+ # Mean of prior distribution
128
+ Distributions. mean (p:: ESSPrior ) = p. μ
129
+
130
+ # Evaluate log-likelihood of proposals
131
+ struct ESSLogLikelihood{M<: Model ,S<: Sampler{<:ESS} ,V<: AbstractVarInfo }
132
+ model:: M
133
+ sampler:: S
134
+ varinfo:: V
119
135
end
120
136
121
- # evaluate log-likelihood
122
- function Distributions. loglikelihood (model:: ESSModel , f)
123
- spl = model. spl
124
- vi = model. vi
125
- vi[spl] = f
126
- model. model (vi, spl)
127
- getlogp (vi)
137
+ function (ℓ:: ESSLogLikelihood )(f)
138
+ sampler = ℓ. sampler
139
+ varinfo = ℓ. varinfo
140
+ varinfo[sampler] = f
141
+ ℓ. model (varinfo, sampler)
142
+ return getlogp (varinfo)
128
143
end
129
144
130
145
function DynamicPPL. tilde (rng, ctx:: DefaultContext , sampler:: Sampler{<:ESS} , right, vn:: VarName , inds, vi)
0 commit comments