Skip to content

Commit 6e32434

Browse files
committed
First efforts towards DPPL 0.37 compat, WIP
1 parent 05110bd commit 6e32434

File tree

16 files changed

+147
-254
lines changed

16 files changed

+147
-254
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ Distributions = "0.25.77"
6262
DistributionsAD = "0.6"
6363
DocStringExtensions = "0.8, 0.9"
6464
DynamicHMC = "3.4"
65-
DynamicPPL = "0.36"
65+
DynamicPPL = "0.37"
6666
EllipticalSliceSampling = "0.5, 1, 2"
6767
ForwardDiff = "0.10.3"
6868
Libtask = "0.8.8"

ext/TuringOptimExt.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ function Optim.optimize(
3434
options::Optim.Options=Optim.Options();
3535
kwargs...,
3636
)
37-
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
38-
f = Optimisation.OptimLogDensity(model, ctx)
37+
vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
38+
f = Optimisation.OptimLogDensity(model, vi)
3939
init_vals = DynamicPPL.getparams(f.ldf)
4040
optimizer = Optim.LBFGS()
4141
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -57,8 +57,8 @@ function Optim.optimize(
5757
options::Optim.Options=Optim.Options();
5858
kwargs...,
5959
)
60-
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
61-
f = Optimisation.OptimLogDensity(model, ctx)
60+
vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
61+
f = Optimisation.OptimLogDensity(model, vi)
6262
init_vals = DynamicPPL.getparams(f.ldf)
6363
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
6464
end
@@ -74,8 +74,9 @@ function Optim.optimize(
7474
end
7575

7676
function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
77-
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
78-
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
77+
vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
78+
f = Optimisation.OptimLogDensity(model, vi)
79+
return _optimize(f, args...; kwargs...)
7980
end
8081

8182
"""
@@ -104,8 +105,8 @@ function Optim.optimize(
104105
options::Optim.Options=Optim.Options();
105106
kwargs...,
106107
)
107-
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
108-
f = Optimisation.OptimLogDensity(model, ctx)
108+
vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
109+
f = Optimisation.OptimLogDensity(model, vi)
109110
init_vals = DynamicPPL.getparams(f.ldf)
110111
optimizer = Optim.LBFGS()
111112
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -127,8 +128,8 @@ function Optim.optimize(
127128
options::Optim.Options=Optim.Options();
128129
kwargs...,
129130
)
130-
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
131-
f = Optimisation.OptimLogDensity(model, ctx)
131+
vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
132+
f = Optimisation.OptimLogDensity(model, vi)
132133
init_vals = DynamicPPL.getparams(f.ldf)
133134
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
134135
end
@@ -144,9 +145,11 @@ function Optim.optimize(
144145
end
145146

146147
function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
147-
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
148-
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
148+
vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
149+
f = Optimisation.OptimLogDensity(model, vi)
150+
return _optimize(f, args...; kwargs...)
149151
end
152+
150153
"""
151154
_optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
152155
@@ -166,7 +169,7 @@ function _optimize(
166169
# whether initialisation is really necessary at all
167170
vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
168171
vi = DynamicPPL.link(vi, f.ldf.model)
169-
f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
172+
f = Optimisation.OptimLogDensity(f.ldf.model, vi; adtype=f.ldf.adtype)
170173
init_vals = DynamicPPL.getparams(f.ldf)
171174

172175
# Optimize!
@@ -183,9 +186,7 @@ function _optimize(
183186
# Get the optimum in unconstrained space. `getparams` does the invlinking.
184187
vi = f.ldf.varinfo
185188
vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
186-
logdensity_optimum = Optimisation.OptimLogDensity(
187-
f.ldf.model, vi_optimum, f.ldf.context
188-
)
189+
logdensity_optimum = Optimisation.OptimLogDensity(f.ldf.model, vi_optimum; adtype=f.ldf.adtype)
189190
vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
190191
varnames = map(Symbol ∘ first, vns_vals_iter)
191192
vals = map(last, vns_vals_iter)

src/mcmc/Inference.jl

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@ using DynamicPPL:
2222
SampleFromPrior,
2323
SampleFromUniform,
2424
DefaultContext,
25-
PriorContext,
26-
LikelihoodContext,
2725
set_flag!,
2826
unset_flag!
2927
using Distributions, Libtask, Bijectors
@@ -75,7 +73,6 @@ export InferenceAlgorithm,
7573
RepeatSampler,
7674
Prior,
7775
assume,
78-
observe,
7976
predict,
8077
externalsampler
8178

@@ -182,12 +179,10 @@ function AbstractMCMC.step(
182179
state=nothing;
183180
kwargs...,
184181
)
182+
vi = VarInfo()
183+
vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.LogPriorAccumulator(),))
185184
vi = last(
186-
DynamicPPL.evaluate!!(
187-
model,
188-
VarInfo(),
189-
SamplingContext(rng, DynamicPPL.SampleFromPrior(), DynamicPPL.PriorContext()),
190-
),
185+
DynamicPPL.evaluate!!(model, vi, SamplingContext(rng, DynamicPPL.SampleFromPrior())),
191186
)
192187
return vi, nothing
193188
end

src/mcmc/ess.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function AbstractMCMC.step(
4949
rng,
5050
EllipticalSliceSampling.ESSModel(
5151
ESSPrior(model, spl, vi),
52-
DynamicPPL.LogDensityFunction(
52+
DynamicPPL.LogDensityFunction{:LogLikelihood}(
5353
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
5454
),
5555
),
@@ -59,7 +59,7 @@ function AbstractMCMC.step(
5959

6060
# update sample and log-likelihood
6161
vi = DynamicPPL.unflatten(vi, sample)
62-
vi = setlogp!!(vi, state.loglikelihood)
62+
vi = setloglikelihood!!(vi, state.loglikelihood)
6363

6464
return Transition(model, vi), vi
6565
end
@@ -108,20 +108,12 @@ end
108108
# Mean of prior distribution
109109
Distributions.mean(p::ESSPrior) = p.μ
110110

111-
# Evaluate log-likelihood of proposals
112-
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
113-
DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
114-
115-
(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)
116-
117111
function DynamicPPL.tilde_assume(
118-
rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
112+
rng::Random.AbstractRNG, ctx::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
119113
)
120-
return DynamicPPL.tilde_assume(
121-
rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
122-
)
114+
return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
123115
end
124116

125-
function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
126-
return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
117+
function DynamicPPL.tilde_observe!!(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi)
118+
return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi)
127119
end

src/mcmc/gibbs.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
3232
#
3333
# Purpose: avoid triggering resampling of variables we're conditioning on.
3434
# - Using standard `DynamicPPL.condition` results in conditioned variables being treated
35-
# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`.
35+
# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`.
3636
# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to
3737
# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable
3838
# rather than only for the "true" observations.
@@ -177,24 +177,26 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
177177
DynamicPPL.tilde_assume(child_context, right, vn, vi)
178178
elseif has_conditioned_gibbs(context, vn)
179179
# Short-circuit the tilde assume if `vn` is present in `context`.
180-
value, lp, _ = DynamicPPL.tilde_assume(
180+
# TODO(mhauru) Fix accumulation here. In this branch anything that gets
181+
# accumulated just gets discarded with `_`.
182+
value, _ = DynamicPPL.tilde_assume(
181183
child_context, right, vn, get_global_varinfo(context)
182184
)
183-
value, lp, vi
185+
value, vi
184186
else
185187
# If the varname has not been conditioned on, nor is it a target variable, its
186188
# presumably a new variable that should be sampled from its prior. We need to add
187189
# this new variable to the global `varinfo` of the context, but not to the local one
188190
# being used by the current sampler.
189-
value, lp, new_global_vi = DynamicPPL.tilde_assume(
191+
value, new_global_vi = DynamicPPL.tilde_assume(
190192
child_context,
191193
DynamicPPL.SampleFromPrior(),
192194
right,
193195
vn,
194196
get_global_varinfo(context),
195197
)
196198
set_global_varinfo!(context, new_global_vi)
197-
value, lp, vi
199+
value, vi
198200
end
199201
end
200202

src/mcmc/hmc.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -501,10 +501,6 @@ function DynamicPPL.assume(
501501
return DynamicPPL.assume(dist, vn, vi)
502502
end
503503

504-
function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
505-
return DynamicPPL.observe(d, value, vi)
506-
end
507-
508504
####
509505
#### Default HMC stepsize and mass matrix adaptor
510506
####

src/mcmc/is.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,3 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName
5555
end
5656
return r, 0, vi
5757
end
58-
59-
function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi)
60-
return logpdf(dist, value), vi
61-
end

src/mcmc/mh.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,3 @@ function DynamicPPL.assume(
390390
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
391391
return retval
392392
end
393-
394-
function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)
395-
return DynamicPPL.observe(SampleFromPrior(), d, value, vi)
396-
end

src/mcmc/particle_mcmc.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -379,10 +379,11 @@ function DynamicPPL.assume(
379379
return r, lp, vi
380380
end
381381

382-
function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
383-
# NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
384-
return logpdf(dist, value), trace_local_varinfo_maybe(vi)
385-
end
382+
# TODO(mhauru) Fix this.
383+
# function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
384+
# # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
385+
# return logpdf(dist, value), trace_local_varinfo_maybe(vi)
386+
# end
386387

387388
function DynamicPPL.acclogp!!(
388389
context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
@@ -391,12 +392,13 @@ function DynamicPPL.acclogp!!(
391392
return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp)
392393
end
393394

394-
function DynamicPPL.acclogp_observe!!(
395-
context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
396-
)
397-
Libtask.produce(logp)
398-
return trace_local_varinfo_maybe(varinfo)
399-
end
395+
# TODO(mhauru) Fix this.
396+
# function DynamicPPL.acclogp_observe!!(
397+
# context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
398+
# )
399+
# Libtask.produce(logp)
400+
# return trace_local_varinfo_maybe(varinfo)
401+
# end
400402

401403
# Convenient constructor
402404
function AdvancedPS.Trace(

0 commit comments

Comments
 (0)