Skip to content

Commit cea1f7d

Browse files
mhauru and penelopeysm
authored and committed
First efforts towards DPPL 0.37 compat, WIP
1 parent 0164e84 commit cea1f7d

File tree

16 files changed

+321
-233
lines changed

16 files changed

+321
-233
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
6464
DistributionsAD = "0.6"
6565
DocStringExtensions = "0.8, 0.9"
6666
DynamicHMC = "3.4"
67-
DynamicPPL = "0.36.3"
67+
DynamicPPL = "0.37"
6868
EllipticalSliceSampling = "0.5, 1, 2"
6969
ForwardDiff = "0.10.3, 1"
7070
Libtask = "0.8.8"

ext/TuringOptimExt.jl

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ function Optim.optimize(
3434
options::Optim.Options=Optim.Options();
3535
kwargs...,
3636
)
37-
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
38-
f = Optimisation.OptimLogDensity(model, ctx)
37+
vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
38+
f = Optimisation.OptimLogDensity(model, vi)
3939
init_vals = DynamicPPL.getparams(f.ldf)
4040
optimizer = Optim.LBFGS()
4141
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -57,8 +57,8 @@ function Optim.optimize(
5757
options::Optim.Options=Optim.Options();
5858
kwargs...,
5959
)
60-
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
61-
f = Optimisation.OptimLogDensity(model, ctx)
60+
vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
61+
f = Optimisation.OptimLogDensity(model, vi)
6262
init_vals = DynamicPPL.getparams(f.ldf)
6363
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
6464
end
@@ -74,8 +74,9 @@ function Optim.optimize(
7474
end
7575

7676
function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
77-
ctx = Optimisation.OptimizationContext(DynamicPPL.LikelihoodContext())
78-
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
77+
vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
78+
f = Optimisation.OptimLogDensity(model, vi)
79+
return _optimize(f, args...; kwargs...)
7980
end
8081

8182
"""
@@ -104,8 +105,8 @@ function Optim.optimize(
104105
options::Optim.Options=Optim.Options();
105106
kwargs...,
106107
)
107-
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
108-
f = Optimisation.OptimLogDensity(model, ctx)
108+
vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
109+
f = Optimisation.OptimLogDensity(model, vi)
109110
init_vals = DynamicPPL.getparams(f.ldf)
110111
optimizer = Optim.LBFGS()
111112
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -127,8 +128,8 @@ function Optim.optimize(
127128
options::Optim.Options=Optim.Options();
128129
kwargs...,
129130
)
130-
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
131-
f = Optimisation.OptimLogDensity(model, ctx)
131+
vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
132+
f = Optimisation.OptimLogDensity(model, vi)
132133
init_vals = DynamicPPL.getparams(f.ldf)
133134
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
134135
end
@@ -144,9 +145,11 @@ function Optim.optimize(
144145
end
145146

146147
function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
147-
ctx = Optimisation.OptimizationContext(DynamicPPL.DefaultContext())
148-
return _optimize(Optimisation.OptimLogDensity(model, ctx), args...; kwargs...)
148+
vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
149+
f = Optimisation.OptimLogDensity(model, vi)
150+
return _optimize(f, args...; kwargs...)
149151
end
152+
150153
"""
151154
_optimize(f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
152155
@@ -166,7 +169,7 @@ function _optimize(
166169
# whether initialisation is really necessary at all
167170
vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
168171
vi = DynamicPPL.link(vi, f.ldf.model)
169-
f = Optimisation.OptimLogDensity(f.ldf.model, vi, f.ldf.context; adtype=f.ldf.adtype)
172+
f = Optimisation.OptimLogDensity(f.ldf.model, vi; adtype=f.ldf.adtype)
170173
init_vals = DynamicPPL.getparams(f.ldf)
171174

172175
# Optimize!
@@ -183,9 +186,7 @@ function _optimize(
183186
# Get the optimum in unconstrained space. `getparams` does the invlinking.
184187
vi = f.ldf.varinfo
185188
vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
186-
logdensity_optimum = Optimisation.OptimLogDensity(
187-
f.ldf.model, vi_optimum, f.ldf.context
188-
)
189+
logdensity_optimum = Optimisation.OptimLogDensity(f.ldf.model, vi_optimum; adtype=f.ldf.adtype)
189190
vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
190191
varnames = map(Symbol ∘ first, vns_vals_iter)
191192
vals = map(last, vns_vals_iter)

src/mcmc/Inference.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ using DynamicPPL:
2626
SampleFromPrior,
2727
SampleFromUniform,
2828
DefaultContext,
29-
PriorContext,
30-
LikelihoodContext,
31-
SamplingContext,
3229
set_flag!,
3330
unset_flag!
3431
using Distributions, Libtask, Bijectors

src/mcmc/ess.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ function AbstractMCMC.step(
4949
rng,
5050
EllipticalSliceSampling.ESSModel(
5151
ESSPrior(model, spl, vi),
52-
DynamicPPL.LogDensityFunction(
52+
DynamicPPL.LogDensityFunction{:LogLikelihood}(
5353
model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
5454
),
5555
),
@@ -59,7 +59,7 @@ function AbstractMCMC.step(
5959

6060
# update sample and log-likelihood
6161
vi = DynamicPPL.unflatten(vi, sample)
62-
vi = setlogp!!(vi, state.loglikelihood)
62+
vi = setloglikelihood!!(vi, state.loglikelihood)
6363

6464
return Transition(model, vi), vi
6565
end
@@ -108,20 +108,12 @@ end
108108
# Mean of prior distribution
109109
Distributions.mean(p::ESSPrior) = p.μ
110110

111-
# Evaluate log-likelihood of proposals
112-
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
113-
DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}
114-
115-
(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)
116-
117111
function DynamicPPL.tilde_assume(
118-
rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
112+
rng::Random.AbstractRNG, ctx::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
119113
)
120-
return DynamicPPL.tilde_assume(
121-
rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
122-
)
114+
return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
123115
end
124116

125-
function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
126-
return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
117+
function DynamicPPL.tilde_observe!!(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi)
118+
return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi)
127119
end

src/mcmc/gibbs.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
3333
#
3434
# Purpose: avoid triggering resampling of variables we're conditioning on.
3535
# - Using standard `DynamicPPL.condition` results in conditioned variables being treated
36-
# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`.
36+
# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe!!`.
3737
# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to
3838
# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable
3939
# rather than only for the "true" observations.
@@ -178,24 +178,26 @@ function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi)
178178
DynamicPPL.tilde_assume(child_context, right, vn, vi)
179179
elseif has_conditioned_gibbs(context, vn)
180180
# Short-circuit the tilde assume if `vn` is present in `context`.
181-
value, lp, _ = DynamicPPL.tilde_assume(
181+
# TODO(mhauru) Fix accumulation here. In this branch anything that gets
182+
# accumulated just gets discarded with `_`.
183+
value, _ = DynamicPPL.tilde_assume(
182184
child_context, right, vn, get_global_varinfo(context)
183185
)
184-
value, lp, vi
186+
value, vi
185187
else
186188
# If the varname has not been conditioned on, nor is it a target variable, its
187189
# presumably a new variable that should be sampled from its prior. We need to add
188190
# this new variable to the global `varinfo` of the context, but not to the local one
189191
# being used by the current sampler.
190-
value, lp, new_global_vi = DynamicPPL.tilde_assume(
192+
value, new_global_vi = DynamicPPL.tilde_assume(
191193
child_context,
192194
DynamicPPL.SampleFromPrior(),
193195
right,
194196
vn,
195197
get_global_varinfo(context),
196198
)
197199
set_global_varinfo!(context, new_global_vi)
198-
value, lp, vi
200+
value, vi
199201
end
200202
end
201203

src/mcmc/hmc.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -516,10 +516,6 @@ function DynamicPPL.assume(
516516
return DynamicPPL.assume(dist, vn, vi)
517517
end
518518

519-
function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
520-
return DynamicPPL.observe(d, value, vi)
521-
end
522-
523519
####
524520
#### Default HMC stepsize and mass matrix adaptor
525521
####

src/mcmc/is.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,3 @@ function DynamicPPL.assume(rng, ::Sampler{<:IS}, dist::Distribution, vn::VarName
5555
end
5656
return r, 0, vi
5757
end
58-
59-
function DynamicPPL.observe(::Sampler{<:IS}, dist::Distribution, value, vi)
60-
return logpdf(dist, value), vi
61-
end

src/mcmc/mh.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,3 @@ function DynamicPPL.assume(
392392
retval = DynamicPPL.assume(rng, SampleFromPrior(), dist, vn, vi)
393393
return retval
394394
end
395-
396-
function DynamicPPL.observe(spl::Sampler{<:MH}, d::Distribution, value, vi)
397-
return DynamicPPL.observe(SampleFromPrior(), d, value, vi)
398-
end

src/mcmc/particle_mcmc.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -450,10 +450,11 @@ function DynamicPPL.assume(
450450
return r, lp, vi
451451
end
452452

453-
function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
454-
# NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
455-
return logpdf(dist, value), trace_local_varinfo_maybe(vi)
456-
end
453+
# TODO(mhauru) Fix this.
454+
# function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi)
455+
# # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`.
456+
# return logpdf(dist, value), trace_local_varinfo_maybe(vi)
457+
# end
457458

458459
function DynamicPPL.acclogp!!(
459460
context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
@@ -462,12 +463,13 @@ function DynamicPPL.acclogp!!(
462463
return DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp)
463464
end
464465

465-
function DynamicPPL.acclogp_observe!!(
466-
context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
467-
)
468-
Libtask.produce(logp)
469-
return trace_local_varinfo_maybe(varinfo)
470-
end
466+
# TODO(mhauru) Fix this.
467+
# function DynamicPPL.acclogp_observe!!(
468+
# context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, varinfo::AbstractVarInfo, logp
469+
# )
470+
# Libtask.produce(logp)
471+
# return trace_local_varinfo_maybe(varinfo)
472+
# end
471473

472474
# Convenient constructor
473475
function AdvancedPS.Trace(

0 commit comments

Comments (0)