Skip to content

Commit 7948bc5

Browse files
committed
More DPPL 0.37 compat work, WIP
1 parent 6e32434 commit 7948bc5

File tree

9 files changed

+161
-93
lines changed

9 files changed

+161
-93
lines changed

ext/TuringOptimExt.jl

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ function Optim.optimize(
3434
options::Optim.Options=Optim.Options();
3535
kwargs...,
3636
)
37-
vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
38-
f = Optimisation.OptimLogDensity(model, vi)
37+
f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
3938
init_vals = DynamicPPL.getparams(f.ldf)
4039
optimizer = Optim.LBFGS()
4140
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -57,8 +56,7 @@ function Optim.optimize(
5756
options::Optim.Options=Optim.Options();
5857
kwargs...,
5958
)
60-
vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
61-
f = Optimisation.OptimLogDensity(model, vi)
59+
f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
6260
init_vals = DynamicPPL.getparams(f.ldf)
6361
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
6462
end
@@ -74,8 +72,7 @@ function Optim.optimize(
7472
end
7573

7674
function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
77-
vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
78-
f = Optimisation.OptimLogDensity(model, vi)
75+
f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
7976
return _optimize(f, args...; kwargs...)
8077
end
8178

@@ -105,8 +102,7 @@ function Optim.optimize(
105102
options::Optim.Options=Optim.Options();
106103
kwargs...,
107104
)
108-
vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
109-
f = Optimisation.OptimLogDensity(model, vi)
105+
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
110106
init_vals = DynamicPPL.getparams(f.ldf)
111107
optimizer = Optim.LBFGS()
112108
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -128,8 +124,7 @@ function Optim.optimize(
128124
options::Optim.Options=Optim.Options();
129125
kwargs...,
130126
)
131-
vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
132-
f = Optimisation.OptimLogDensity(model, vi)
127+
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
133128
init_vals = DynamicPPL.getparams(f.ldf)
134129
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
135130
end
@@ -145,8 +140,7 @@ function Optim.optimize(
145140
end
146141

147142
function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
148-
vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
149-
f = Optimisation.OptimLogDensity(model, vi)
143+
f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
150144
return _optimize(f, args...; kwargs...)
151145
end
152146

@@ -169,7 +163,9 @@ function _optimize(
169163
# whether initialisation is really necessary at all
170164
vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
171165
vi = DynamicPPL.link(vi, f.ldf.model)
172-
f = Optimisation.OptimLogDensity(f.ldf.model, vi; adtype=f.ldf.adtype)
166+
f = Optimisation.OptimLogDensity(
167+
f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype
168+
)
173169
init_vals = DynamicPPL.getparams(f.ldf)
174170

175171
# Optimize!
@@ -186,7 +182,9 @@ function _optimize(
186182
# Get the optimum in unconstrained space. `getparams` does the invlinking.
187183
vi = f.ldf.varinfo
188184
vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
189-
logdensity_optimum = Optimisation.OptimLogDensity(f.ldf.model, vi_optimum; adtype=f.ldf.adtype)
185+
logdensity_optimum = Optimisation.OptimLogDensity(
186+
f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype
187+
)
190188
vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
191189
varnames = map(Symbol ∘ first, vns_vals_iter)
192190
vals = map(last, vns_vals_iter)

src/essential/container.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ function AdvancedPS.advance!(
2828
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
2929
)
3030
# Make sure we load/reset the rng in the new replaying mechanism
31-
DynamicPPL.increment_num_produce!(trace.model.f.varinfo)
31+
# TODO(mhauru) Stop ignoring the return value.
32+
DynamicPPL.increment_num_produce!!(trace.model.f.varinfo)
3233
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
3334
score = consume(trace.model.ctask)
3435
if score === nothing
@@ -44,11 +45,13 @@ function AdvancedPS.delete_retained!(trace::TracedModel)
4445
end
4546

4647
function AdvancedPS.reset_model(trace::TracedModel)
47-
DynamicPPL.reset_num_produce!(trace.varinfo)
48+
new_vi = DynamicPPL.reset_num_produce!!(trace.varinfo)
49+
trace = TracedModel(trace.model, trace.sampler, new_vi, trace.evaluator)
4850
return trace
4951
end
5052

5153
function AdvancedPS.reset_logprob!(trace::TracedModel)
54+
# TODO(mhauru) Stop ignoring the return value.
5255
DynamicPPL.resetlogp!!(trace.model.varinfo)
5356
return trace
5457
end

src/mcmc/Inference.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ using DynamicPPL:
1414
push!!,
1515
setlogp!!,
1616
getlogp,
17+
getlogjoint,
1718
VarName,
1819
getsym,
1920
getdist,
@@ -182,7 +183,7 @@ function AbstractMCMC.step(
182183
vi = VarInfo()
183184
vi = DynamicPPL.setaccs!!(vi, (DynamicPPL.LogPriorAccumulator(),))
184185
vi = last(
185-
DynamicPPL.evaluate!!(model, vi, SamplingContext(rng, DynamicPPL.SampleFromPrior())),
186+
DynamicPPL.evaluate!!(model, vi, SamplingContext(rng, DynamicPPL.SampleFromPrior()))
186187
)
187188
return vi, nothing
188189
end
@@ -223,7 +224,7 @@ end
223224
Transition(θ, lp) = Transition(θ, lp, nothing)
224225
function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
225226
θ = getparams(model, vi)
226-
lp = getlogp(vi)
227+
lp = getlogjoint(vi)
227228
return Transition(θ, lp, getstats(t))
228229
end
229230

@@ -236,10 +237,10 @@ function metadata(t::Transition)
236237
end
237238
end
238239

239-
DynamicPPL.getlogp(t::Transition) = t.lp
240+
DynamicPPL.getlogjoint(t::Transition) = t.lp
240241

241242
# Metadata of VarInfo object
242-
metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),)
243+
metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)
243244

244245
# TODO: Implement additional checks for certain samplers, e.g.
245246
# HMC not supporting discrete parameters.
@@ -376,7 +377,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
376377
end
377378

378379
function get_transition_extras(ts::AbstractVector{<:VarInfo})
379-
valmat = reshape([getlogp(t) for t in ts], :, 1)
380+
valmat = reshape([getlogjoint(t) for t in ts], :, 1)
380381
return [:lp], valmat
381382
end
382383

@@ -589,7 +590,7 @@ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
589590
590591
julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
591592
592-
julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints
593+
julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints
593594
2-element Array{Float64,1}:
594595
-3.6294991938628374
595596
-2.5697948166987845

src/mcmc/ess.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ function DynamicPPL.tilde_assume(
114114
return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
115115
end
116116

117-
function DynamicPPL.tilde_observe!!(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi)
117+
function DynamicPPL.tilde_observe!!(
118+
ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi
119+
)
118120
return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi)
119121
end

src/mcmc/hmc.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ function DynamicPPL.initialstep(
199199
end
200200

201201
# Cache current log density.
202-
log_density_old = getlogp(vi)
202+
log_density_old = getloglikelihood(vi)
203203

204204
# Find good eps if not provided one
205205
if iszero(spl.alg.ϵ)
@@ -227,10 +227,12 @@ function DynamicPPL.initialstep(
227227
# Update `vi` based on acceptance
228228
if t.stat.is_accept
229229
vi = DynamicPPL.unflatten(vi, t.z.θ)
230-
vi = setlogp!!(vi, t.stat.log_density)
230+
# TODO(mhauru) Is setloglikelihood! the right thing here?
231+
vi = setloglikelihood!!(vi, t.stat.log_density)
231232
else
232233
vi = DynamicPPL.unflatten(vi, theta)
233-
vi = setlogp!!(vi, log_density_old)
234+
# TODO(mhauru) Is setloglikelihood! the right thing here?
235+
vi = setloglikelihood!!(vi, log_density_old)
234236
end
235237

236238
transition = Transition(model, vi, t)
@@ -275,7 +277,8 @@ function AbstractMCMC.step(
275277
vi = state.vi
276278
if t.stat.is_accept
277279
vi = DynamicPPL.unflatten(vi, t.z.θ)
278-
vi = setlogp!!(vi, t.stat.log_density)
280+
# TODO(mhauru) Is setloglikelihood! the right thing here?
281+
vi = setloglikelihood!!(vi, t.stat.log_density)
279282
end
280283

281284
# Compute next transition and state.

src/mcmc/particle_mcmc.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ function DynamicPPL.initialstep(
118118
kwargs...,
119119
)
120120
# Reset the VarInfo.
121-
reset_num_produce!(vi)
121+
vi = reset_num_produce!!(vi)
122122
set_retained_vns_del!(vi)
123-
resetlogp!!(vi)
124-
empty!!(vi)
123+
vi = resetlogp!!(vi)
124+
vi = empty!!(vi)
125125

126126
# Create a new set of particles.
127127
particles = AdvancedPS.ParticleContainer(
@@ -252,9 +252,9 @@ function DynamicPPL.initialstep(
252252
kwargs...,
253253
)
254254
# Reset the VarInfo before new sweep
255-
reset_num_produce!(vi)
255+
vi = reset_num_produce!!(vi)
256256
set_retained_vns_del!(vi)
257-
resetlogp!!(vi)
257+
vi = resetlogp!!(vi)
258258

259259
# Create a new set of particles
260260
num_particles = spl.alg.nparticles
@@ -284,8 +284,8 @@ function AbstractMCMC.step(
284284
)
285285
# Reset the VarInfo before new sweep.
286286
vi = state.vi
287-
reset_num_produce!(vi)
288-
resetlogp!!(vi)
287+
vi = reset_num_produce!!(vi)
288+
vi = resetlogp!!(vi)
289289

290290
# Create reference particle for which the samples will be retained.
291291
reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng))
@@ -408,7 +408,7 @@ function AdvancedPS.Trace(
408408
rng::AdvancedPS.TracedRNG,
409409
)
410410
newvarinfo = deepcopy(varinfo)
411-
DynamicPPL.reset_num_produce!(newvarinfo)
411+
newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo)
412412

413413
tmodel = Turing.Essential.TracedModel(model, sampler, newvarinfo, rng)
414414
newtrace = AdvancedPS.Trace(tmodel, rng)

0 commit comments

Comments
 (0)