Skip to content

Commit a7c78ba

Browse files
torfjelde, devmotion, yebai
authored
Reduce usage of sampler (#1936)
* initial work on using less of sampler and more of context
* fixed optim and ad
* fixed ESS and MH
* bump version
* fixed typo
* moved all LogDensityProblems related to DPPL
* Apply suggestions from code review

Co-authored-by: David Widmann <[email protected]>
* Apply suggestions from code review

Co-authored-by: Hong Ge <[email protected]>
* make compat bounds correct
* fixed bug in MH
* fixed MH again
* fixed Emcee
* fixed emcee
* fixed broken tests and removed mentionings of sampler in optimization
* fixed bug in optim
* added tests for demo models
* fixed missing support for certain models in optim
* fixed type
* fix
* use LogDensityModel instead of wrapping in Base.Fix
* Revert "use LogDensityModel instead of wrapping in Base.Fix"

This reverts commit 596b0b2.
* disable failing unsupported models for reverse mode AD frameworks

Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
1 parent c9409e2 commit a7c78ba

File tree

10 files changed

+131
-103
lines changed

10 files changed

+131
-103
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ DataStructures = "0.18"
4747
Distributions = "0.23.3, 0.24, 0.25"
4848
DistributionsAD = "0.6"
4949
DocStringExtensions = "0.8, 0.9"
50-
DynamicPPL = "0.21"
50+
DynamicPPL = "0.21.5"
5151
EllipticalSliceSampling = "0.5, 1"
5252
ForwardDiff = "0.10.3"
5353
Libtask = "0.7, 0.8"

src/Turing.jl

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Libtask
88
using Tracker: Tracker
99

1010
import AdvancedVI
11+
using DynamicPPL: DynamicPPL, LogDensityFunction
1112
import DynamicPPL: getspace, NoDist, NamedDist
1213
import LogDensityProblems
1314
import Random
@@ -26,26 +27,6 @@ function setprogress!(progress::Bool)
2627
return progress
2728
end
2829

29-
# Log density function
30-
struct LogDensityFunction{V,M,S,C}
31-
varinfo::V
32-
model::M
33-
sampler::S
34-
context::C
35-
end
36-
37-
function (f::LogDensityFunction)(θ::AbstractVector)
38-
vi_new = DynamicPPL.unflatten(f.varinfo, f.sampler, θ)
39-
return getlogp(last(DynamicPPL.evaluate!!(f.model, vi_new, f.sampler, f.context)))
40-
end
41-
42-
# LogDensityProblems interface
43-
LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) = f(θ)
44-
LogDensityProblems.dimension(f::LogDensityFunction) = length(f.varinfo[f.sampler])
45-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
46-
return LogDensityProblems.LogDensityOrder{0}()
47-
end
48-
4930
# Standard tag: Improves stacktraces
5031
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
5132
struct TuringTag end
@@ -154,6 +135,7 @@ export @model, # modelling
154135
generated_quantities,
155136
logprior,
156137
logjoint,
138+
LogDensityFunction,
157139

158140
constrained_space, # optimisation interface
159141
MAP,

src/essential/ad.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,18 @@ Find the autodifferentiation backend of the algorithm `alg`.
7777
"""
7878
getADbackend(spl::Sampler) = getADbackend(spl.alg)
7979
getADbackend(::SampleFromPrior) = ADBackend()()
80+
getADbackend(ctx::DynamicPPL.SamplingContext) = getADbackend(ctx.sampler)
81+
getADbackend(ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.NodeTrait(ctx), ctx)
82+
83+
getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()()
84+
getADbackend(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.childcontext(ctx))
8085

8186
function LogDensityProblemsAD.ADgradient(ℓ::Turing.LogDensityFunction)
82-
return LogDensityProblemsAD.ADgradient(getADbackend(ℓ.sampler), ℓ)
87+
return LogDensityProblemsAD.ADgradient(getADbackend(ℓ.context), ℓ)
8388
end
8489

8590
function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensityFunction)
86-
θ = ℓ.varinfo[ℓ.sampler]
91+
θ = DynamicPPL.getparams(ℓ)
8792
f = Base.Fix1(LogDensityProblems.logdensity, ℓ)
8893

8994
# Define configuration for ForwardDiff.

src/inference/emcee.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ function AbstractMCMC.step(
7474
)
7575
# Generate a log joint function.
7676
vi = state.vi
77-
densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, SampleFromPrior(), DynamicPPL.DefaultContext()))
77+
densitymodel = AMH.DensityModel(
78+
Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(model, vi))
79+
)
7880

7981
# Compute the next states.
8082
states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))

src/inference/ess.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,10 @@ end
124124
Distributions.mean(p::ESSPrior) = p.μ
125125

126126
# Evaluate log-likelihood of proposals
127-
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext()}
127+
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.SamplingContext{<:S}}
128128

129129
function (ℓ::ESSLogLikelihood)(f::AbstractVector)
130-
sampler = ℓ.sampler
130+
sampler = DynamicPPL.getsampler(ℓ)
131131
varinfo = setindex!!(ℓ.varinfo, f, sampler)
132132
varinfo = last(DynamicPPL.evaluate!!(ℓ.model, varinfo, sampler))
133133
return getlogp(varinfo)

src/inference/mh.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -246,11 +246,11 @@ A log density function for the MH sampler.
246246
247247
This variant uses the `set_namedtuple!` function to update the `VarInfo`.
248248
"""
249-
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}
249+
const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.SamplingContext{<:S}}
250250

251-
function (f::MHLogDensityFunction)(x::NamedTuple)
251+
function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple)
252252
# TODO: Make this work with immutable `f.varinfo` too.
253-
sampler = f.sampler
253+
sampler = DynamicPPL.getsampler(f)
254254
vi = f.varinfo
255255

256256
x_old, lj_old = vi[sampler], getlogp(vi)
@@ -374,7 +374,9 @@ function propose!!(
374374
prev_trans = AMH.Transition(vt, getlogp(vi))
375375

376376
# Make a new transition.
377-
densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
377+
densitymodel = AMH.DensityModel(
378+
Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
379+
)
378380
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
379381

380382
# TODO: Make this compatible with immutable `VarInfo`.
@@ -400,7 +402,9 @@ function propose!!(
400402
prev_trans = AMH.Transition(vals, getlogp(vi))
401403

402404
# Make a new transition.
403-
densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
405+
densitymodel = AMH.DensityModel(
406+
Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
407+
)
404408
trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)
405409

406410
return setlogp!!(DynamicPPL.unflatten(vi, spl, trans.params), trans.lp)

src/modes/ModeEstimation.jl

Lines changed: 44 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -46,45 +46,42 @@ DynamicPPL.childcontext(context::OptimizationContext) = context.context
4646
DynamicPPL.setchildcontext(::OptimizationContext, child) = OptimizationContext(child)
4747

4848
# assume
49-
function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, spl, dist, vn, vi)
50-
return DynamicPPL.tilde_assume(ctx, spl, dist, vn, vi)
51-
end
52-
53-
function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn, vi)
54-
r = vi[vn]
49+
function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, dist, vn, vi)
50+
r = vi[vn, dist]
5551
return r, 0, vi
5652
end
5753

58-
function DynamicPPL.tilde_assume(ctx::OptimizationContext, spl, dist, vn, vi)
59-
r = vi[vn]
54+
function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi)
55+
r = vi[vn, dist]
6056
return r, Distributions.logpdf(dist, r), vi
6157
end
6258

6359
# dot assume
64-
function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, sampler, right, left, vns, vi)
65-
return DynamicPPL.dot_tilde_assume(ctx, sampler, right, left, vns, vi)
66-
end
67-
68-
function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, sampler::SampleFromPrior, right, left, vns, vi)
60+
function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, right, left, vns, vi)
6961
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
7062
# affect anything.
71-
r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
63+
# TODO: Stop using `get_and_set_val!`.
64+
r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
7265
return r, 0, vi
7366
end
7467

75-
function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, vi)
68+
_loglikelihood(dist::Distribution, x) = loglikelihood(dist, x)
69+
_loglikelihood(dists::AbstractArray{<:Distribution}, x) = loglikelihood(arraydist(dists), x)
70+
71+
function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns, vi)
7672
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
7773
# affect anything.
78-
r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
79-
return r, loglikelihood(right, r), vi
74+
# TODO: Stop using `get_and_set_val!`.
75+
r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
76+
return r, _loglikelihood(right, r), vi
8077
end
8178

8279
"""
8380
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}
8481
8582
A struct that stores the negative log density function of a `DynamicPPL` model.
8683
"""
87-
const OptimLogDensity{M<:Model,C<:OptimizationContext,V<:VarInfo} = Turing.LogDensityFunction{V,M,DynamicPPL.SampleFromPrior,C}
84+
const OptimLogDensity{M<:Model,C<:OptimizationContext,V<:VarInfo} = Turing.LogDensityFunction{V,M,C}
8885

8986
"""
9087
OptimLogDensity(model::Model, context::OptimizationContext)
@@ -93,21 +90,23 @@ Create a callable `OptimLogDensity` struct that evaluates a model using the give
9390
"""
9491
function OptimLogDensity(model::Model, context::OptimizationContext)
9592
init = VarInfo(model)
96-
return Turing.LogDensityFunction(init, model, DynamicPPL.SampleFromPrior(), context)
93+
return Turing.LogDensityFunction(init, model, context)
9794
end
9895

9996
"""
100-
(f::OptimLogDensity)(z)
97+
LogDensityProblems.logdensity(f::OptimLogDensity, z)
10198
10299
Evaluate the negative log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
103100
at the array `z`.
104101
"""
105102
function (f::OptimLogDensity)(z::AbstractVector)
106-
sampler = f.sampler
107-
varinfo = DynamicPPL.unflatten(f.varinfo, sampler, z)
108-
return -getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, sampler, f.context)))
103+
varinfo = DynamicPPL.unflatten(f.varinfo, z)
104+
return -getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, f.context)))
109105
end
110106

107+
# NOTE: This seems a bit weird IMO since this is the _negative_ log-likelihood.
108+
LogDensityProblems.logdensity(f::OptimLogDensity, z::AbstractVector) = f(z)
109+
111110
function (f::OptimLogDensity)(F, G, z)
112111
if G !== nothing
113112
# Calculate negative log joint and its gradient.
@@ -127,7 +126,7 @@ function (f::OptimLogDensity)(F, G, z)
127126

128127
# Only negative log joint requested but no gradient.
129128
if F !== nothing
130-
return f(z)
129+
return LogDensityProblems.logdensity(f, z)
131130
end
132131

133132
return nothing
@@ -140,50 +139,44 @@ end
140139
#################################################
141140

142141
function transform!!(f::OptimLogDensity)
143-
spl = f.sampler
144-
145142
## Check link status of vi in OptimLogDensity
146-
linked = DynamicPPL.islinked(f.varinfo, spl)
143+
linked = DynamicPPL.istrans(f.varinfo)
147144

148145
## transform into constrained or unconstrained space depending on current state of vi
149146
@set! f.varinfo = if !linked
150-
DynamicPPL.link!!(f.varinfo, spl, f.model)
147+
DynamicPPL.link!!(f.varinfo, f.model)
151148
else
152-
DynamicPPL.invlink!!(f.varinfo, spl, f.model)
149+
DynamicPPL.invlink!!(f.varinfo, f.model)
153150
end
154151

155152
return f
156153
end
157154

158155
function transform!!(p::AbstractArray, vi::DynamicPPL.VarInfo, model::DynamicPPL.Model, ::constrained_space{true})
159-
spl = DynamicPPL.SampleFromPrior()
160-
161-
linked = DynamicPPL.islinked(vi, spl)
156+
linked = DynamicPPL.istrans(vi)
162157

163158
!linked && return identity(p) # TODO: why do we do `identity` here?
164-
vi = DynamicPPL.setindex!!(vi, p, spl)
165-
vi = DynamicPPL.invlink!!(vi, spl, model)
166-
p .= vi[spl]
159+
vi = DynamicPPL.unflatten(vi, p)
160+
vi = DynamicPPL.invlink!!(vi, model)
161+
p .= vi[:]
167162

168163
# If linking mutated, we need to link once more.
169-
linked && DynamicPPL.link!!(vi, spl, model)
164+
linked && DynamicPPL.link!!(vi, model)
170165

171166
return p
172167
end
173168

174169
function transform!!(p::AbstractArray, vi::DynamicPPL.VarInfo, model::DynamicPPL.Model, ::constrained_space{false})
175-
spl = DynamicPPL.SampleFromPrior()
176-
177-
linked = DynamicPPL.islinked(vi, spl)
170+
linked = DynamicPPL.istrans(vi)
178171
if linked
179-
vi = DynamicPPL.invlink!!(vi, spl, model)
172+
vi = DynamicPPL.invlink!!(vi, model)
180173
end
181-
vi = DynamicPPL.setindex!!(vi, p, spl)
182-
vi = DynamicPPL.link!!(vi, spl, model)
183-
p .= vi[spl]
174+
vi = DynamicPPL.unflatten(vi, p)
175+
vi = DynamicPPL.link!!(vi, model)
176+
p .= vi[:]
184177

185178
# If linking mutated, we need to link once more.
186-
!linked && DynamicPPL.invlink!!(vi, spl, model)
179+
!linked && DynamicPPL.invlink!!(vi, model)
187180

188181
return p
189182
end
@@ -208,26 +201,26 @@ end
208201

209202
function (t::AbstractTransform)(p::AbstractArray)
210203
return transform(p, t.vi, t.model, t.space)
211-
end
204+
end
212205

213206
function (t::Init)()
214207
return t.vi[DynamicPPL.SampleFromPrior()]
215208
end
216209

217210
function get_parameter_bounds(model::DynamicPPL.Model)
218211
vi = DynamicPPL.VarInfo(model)
219-
spl = DynamicPPL.SampleFromPrior()
220212

221213
## Check link status of vi
222-
linked = DynamicPPL.islinked(vi, spl)
214+
linked = DynamicPPL.istrans(vi)
223215

224216
## transform into unconstrained
225217
if !linked
226-
vi = DynamicPPL.link!!(vi, spl, model)
218+
vi = DynamicPPL.link!!(vi, model)
227219
end
228-
229-
lb = transform(fill(-Inf,length(vi[DynamicPPL.SampleFromPrior()])), vi, model, constrained_space{true}())
230-
ub = transform(fill(Inf,length(vi[DynamicPPL.SampleFromPrior()])), vi, model, constrained_space{true}())
220+
221+
d = length(vi[:])
222+
lb = transform(fill(-Inf, d), vi, model, constrained_space{true}())
223+
ub = transform(fill(Inf, d), vi, model, constrained_space{true}())
231224

232225
return lb, ub
233226
end

0 commit comments

Comments
 (0)