Commit 5d860d9

mhauru authored and penelopeysm committed

More DPPL 0.37 compat work, WIP

1 parent cea1f7d commit 5d860d9

File tree

9 files changed
+229 -91 lines changed

ext/TuringOptimExt.jl

Lines changed: 12 additions & 14 deletions

@@ -34,8 +34,7 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
-    f = Optimisation.OptimLogDensity(model, vi)
+    f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
     init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -57,8 +56,7 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
-    f = Optimisation.OptimLogDensity(model, vi)
+    f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
     init_vals = DynamicPPL.getparams(f.ldf)
     return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -74,8 +72,7 @@ function Optim.optimize(
 end

 function _mle_optimize(model::DynamicPPL.Model, args...; kwargs...)
-    vi = DynamicPPL.setaccs!!(VarInfo(model), (DynamicPPL.LogLikelihoodAccumulator(),))
-    f = Optimisation.OptimLogDensity(model, vi)
+    f = Optimisation.OptimLogDensity(model, DynamicPPL.getloglikelihood)
     return _optimize(f, args...; kwargs...)
 end

@@ -105,8 +102,7 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
-    f = Optimisation.OptimLogDensity(model, vi)
+    f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
     init_vals = DynamicPPL.getparams(f.ldf)
     optimizer = Optim.LBFGS()
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -128,8 +124,7 @@ function Optim.optimize(
     options::Optim.Options=Optim.Options();
     kwargs...,
 )
-    vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
-    f = Optimisation.OptimLogDensity(model, vi)
+    f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
     init_vals = DynamicPPL.getparams(f.ldf)
     return _map_optimize(model, init_vals, optimizer, options; kwargs...)
 end
@@ -145,8 +140,7 @@ function Optim.optimize(
 end

 function _map_optimize(model::DynamicPPL.Model, args...; kwargs...)
-    vi = DynamicPPL.setaccs!!(VarInfo(model), (LogPriorWithoutJacobianAccumulator(), DynamicPPL.LogLikelihoodAccumulator(),))
-    f = Optimisation.OptimLogDensity(model, vi)
+    f = Optimisation.OptimLogDensity(model, Optimisation.getlogjoint_without_jacobian)
     return _optimize(f, args...; kwargs...)
 end

@@ -169,7 +163,9 @@ function _optimize(
     # whether initialisation is really necessary at all
     vi = DynamicPPL.unflatten(f.ldf.varinfo, init_vals)
     vi = DynamicPPL.link(vi, f.ldf.model)
-    f = Optimisation.OptimLogDensity(f.ldf.model, vi; adtype=f.ldf.adtype)
+    f = Optimisation.OptimLogDensity(
+        f.ldf.model, f.ldf.getlogdensity, vi; adtype=f.ldf.adtype
+    )
     init_vals = DynamicPPL.getparams(f.ldf)

     # Optimize!
@@ -186,7 +182,9 @@ function _optimize(
     # Get the optimum in unconstrained space. `getparams` does the invlinking.
     vi = f.ldf.varinfo
     vi_optimum = DynamicPPL.unflatten(vi, M.minimizer)
-    logdensity_optimum = Optimisation.OptimLogDensity(f.ldf.model, vi_optimum; adtype=f.ldf.adtype)
+    logdensity_optimum = Optimisation.OptimLogDensity(
+        f.ldf.model, f.ldf.getlogdensity, vi_optimum; adtype=f.ldf.adtype
+    )
     vns_vals_iter = Turing.Inference.getparams(f.ldf.model, vi_optimum)
     varnames = map(Symbol ∘ first, vns_vals_iter)
     vals = map(last, vns_vals_iter)
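
These hunks replace the old pattern of pre-configuring a VarInfo with accumulators: OptimLogDensity now takes a log-density getter directly, DynamicPPL.getloglikelihood for MLE and Optimisation.getlogjoint_without_jacobian for MAP. A minimal usage sketch of the public API these internals serve, assuming it is unchanged by this commit (the demo model is hypothetical):

using Turing, Optim

@model function demo(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = demo(1.5)

mle_estimate = optimize(model, MLE())  # objective: log-likelihood only
map_estimate = optimize(model, MAP())  # objective: prior (no Jacobian) + likelihood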

src/essential/container.jl

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+struct TracedModel{S<:AbstractSampler,V<:AbstractVarInfo,M<:Model,E<:Tuple} <:
+       AdvancedPS.AbstractGenericModel
+    model::M
+    sampler::S
+    varinfo::V
+    evaluator::E
+end
+
+function TracedModel(
+    model::Model,
+    sampler::AbstractSampler,
+    varinfo::AbstractVarInfo,
+    rng::Random.AbstractRNG,
+)
+    context = SamplingContext(rng, sampler, DefaultContext())
+    args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(model, varinfo, context)
+    if kwargs !== nothing && !isempty(kwargs)
+        error(
+            "Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
+        )
+    end
+    return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
+        model, sampler, varinfo, (model.f, args...)
+    )
+end
+
+function AdvancedPS.advance!(
+    trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}, isref::Bool=false
+)
+    # Make sure we load/reset the rng in the new replaying mechanism
+    # TODO(mhauru) Stop ignoring the return value.
+    DynamicPPL.increment_num_produce!!(trace.model.f.varinfo)
+    isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
+    score = consume(trace.model.ctask)
+    if score === nothing
+        return nothing
+    else
+        return score + DynamicPPL.getlogp(trace.model.f.varinfo)
+    end
+end
+
+function AdvancedPS.delete_retained!(trace::TracedModel)
+    DynamicPPL.set_retained_vns_del!(trace.varinfo)
+    return trace
+end
+
+function AdvancedPS.reset_model(trace::TracedModel)
+    new_vi = DynamicPPL.reset_num_produce!!(trace.varinfo)
+    trace = TracedModel(trace.model, trace.sampler, new_vi, trace.evaluator)
+    return trace
+end
+
+function AdvancedPS.reset_logprob!(trace::TracedModel)
+    # TODO(mhauru) Stop ignoring the return value.
+    DynamicPPL.resetlogp!!(trace.model.varinfo)
+    return trace
+end
+
+function AdvancedPS.update_rng!(
+    trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
+)
+    # Extract the `args`.
+    args = trace.model.ctask.args
+    # From `args`, extract the `SamplingContext`, which contains the RNG.
+    sampling_context = args[3]
+    rng = sampling_context.rng
+    trace.rng = rng
+    return trace
+end
+
+function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ?
+    return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
+end
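
TracedModel bundles a model with its sampler, VarInfo, and a pre-built evaluator tuple (model.f, args...), which lets Libtask.TapedTask replay the model as a plain function call. A sketch of how the constructor is driven, based on the AdvancedPS.Trace call site in src/mcmc/particle_mcmc.jl below (model, sampler, varinfo, and rng are assumed to be in scope):

# Copy the VarInfo and reset its produce counter before wrapping it.
newvarinfo = DynamicPPL.reset_num_produce!!(deepcopy(varinfo))
tmodel = TracedModel(model, sampler, newvarinfo, rng)
newtrace = AdvancedPS.Trace(tmodel, rng)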

src/mcmc/Inference.jl

Lines changed: 6 additions & 5 deletions

@@ -18,6 +18,7 @@ using DynamicPPL:
     push!!,
     setlogp!!,
     getlogp,
+    getlogjoint,
     VarName,
     getsym,
     getdist,
@@ -136,7 +137,7 @@ end
 Transition(θ, lp) = Transition(θ, lp, nothing)
 function Transition(model::DynamicPPL.Model, vi::AbstractVarInfo, t)
     θ = getparams(model, vi)
-    lp = getlogp(vi)
+    lp = getlogjoint(vi)
     return Transition(θ, lp, getstats(t))
 end

@@ -149,10 +150,10 @@ function metadata(t::Transition)
     end
 end

-DynamicPPL.getlogp(t::Transition) = t.lp
+DynamicPPL.getlogjoint(t::Transition) = t.lp

 # Metadata of VarInfo object
-metadata(vi::AbstractVarInfo) = (lp=getlogp(vi),)
+metadata(vi::AbstractVarInfo) = (lp=getlogjoint(vi),)

 ##########################
 # Chain making utilities #
@@ -215,7 +216,7 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
 end

 function get_transition_extras(ts::AbstractVector{<:VarInfo})
-    valmat = reshape([getlogp(t) for t in ts], :, 1)
+    valmat = reshape([getlogjoint(t) for t in ts], :, 1)
     return [:lp], valmat
 end

@@ -434,7 +435,7 @@ julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`

 julia> transitions = Turing.Inference.transitions_from_chain(m, chain);

-julia> [Turing.Inference.getlogp(t) for t in transitions] # extract the logjoints
+julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints
 2-element Array{Float64,1}:
  -3.6294991938628374
  -2.5697948166987845
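
The getlogp → getlogjoint renames track DPPL 0.37's split of the single log-probability field into separate prior and likelihood accumulators; getlogjoint returns their sum. A minimal sketch, assuming the 0.37 accessors behave as these renames suggest (the coin model is hypothetical):

using DynamicPPL, Distributions

@model function coin(y)
    p ~ Beta(1, 1)
    y ~ Bernoulli(p)
end

vi = VarInfo(coin(true))              # evaluate the model once
lp = DynamicPPL.getlogjoint(vi)       # log prior + log likelihood, replacing getlogp(vi)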

src/mcmc/ess.jl

Lines changed: 3 additions & 1 deletion

@@ -114,6 +114,8 @@ function DynamicPPL.tilde_assume(
     return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, vi)
 end

-function DynamicPPL.tilde_observe!!(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi)
+function DynamicPPL.tilde_observe!!(
+    ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vn, vi
+)
     return DynamicPPL.tilde_observe!!(ctx, SampleFromPrior(), right, left, vn, vi)
 end

src/mcmc/hmc.jl

Lines changed: 7 additions & 4 deletions

@@ -214,7 +214,7 @@ function DynamicPPL.initialstep(
     theta = vi[:]

     # Cache current log density.
-    log_density_old = getlogp(vi)
+    log_density_old = getloglikelihood(vi)

     # Find good eps if not provided one
     if iszero(spl.alg.ϵ)
@@ -242,10 +242,12 @@ function DynamicPPL.initialstep(
     # Update `vi` based on acceptance
     if t.stat.is_accept
         vi = DynamicPPL.unflatten(vi, t.z.θ)
-        vi = setlogp!!(vi, t.stat.log_density)
+        # TODO(mhauru) Is setloglikelihood! the right thing here?
+        vi = setloglikelihood!!(vi, t.stat.log_density)
     else
         vi = DynamicPPL.unflatten(vi, theta)
-        vi = setlogp!!(vi, log_density_old)
+        # TODO(mhauru) Is setloglikelihood! the right thing here?
+        vi = setloglikelihood!!(vi, log_density_old)
     end

     transition = Transition(model, vi, t)
@@ -290,7 +292,8 @@ function AbstractMCMC.step(
     vi = state.vi
     if t.stat.is_accept
         vi = DynamicPPL.unflatten(vi, t.z.θ)
-        vi = setlogp!!(vi, t.stat.log_density)
+        # TODO(mhauru) Is setloglikelihood! the right thing here?
+        vi = setloglikelihood!!(vi, t.stat.log_density)
     end

     # Compute next transition and state.
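
With the accumulators split there is no single setlogp!! to write the cached log density back after the accept/reject decision, so these hunks route it through setloglikelihood!! and rebind the result; the TODOs flag that the choice of accumulator is still under review. The accept branch, with names as in the diff:

vi = DynamicPPL.unflatten(vi, t.z.θ)
vi = setloglikelihood!!(vi, t.stat.log_density)  # may return a new VarInfo; rebind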

src/mcmc/particle_mcmc.jl

Lines changed: 9 additions & 9 deletions

@@ -193,10 +193,10 @@ function DynamicPPL.initialstep(
     kwargs...,
 )
     # Reset the VarInfo.
-    DynamicPPL.reset_num_produce!(vi)
-    DynamicPPL.set_retained_vns_del!(vi)
-    DynamicPPL.resetlogp!!(vi)
-    DynamicPPL.empty!!(vi)
+    vi = DynamicPPL.reset_num_produce!!(vi)
+    set_retained_vns_del!(vi)
+    vi = DynamicPPL.resetlogp!!(vi)
+    vi = DynamicPPL.empty!!(vi)

     # Create a new set of particles.
     particles = AdvancedPS.ParticleContainer(
@@ -327,9 +327,9 @@ function DynamicPPL.initialstep(
     kwargs...,
 )
     # Reset the VarInfo before new sweep
-    DynamicPPL.reset_num_produce!(vi)
+    vi = DynamicPPL.reset_num_produce!!(vi)
     DynamicPPL.set_retained_vns_del!(vi)
-    DynamicPPL.resetlogp!!(vi)
+    vi = DynamicPPL.resetlogp!!(vi)

     # Create a new set of particles
     num_particles = spl.alg.nparticles
@@ -359,8 +359,8 @@ function AbstractMCMC.step(
     )
     # Reset the VarInfo before new sweep.
     vi = state.vi
-    DynamicPPL.reset_num_produce!(vi)
-    DynamicPPL.resetlogp!!(vi)
+    vi = DynamicPPL.reset_num_produce!!(vi)
+    vi = DynamicPPL.resetlogp!!(vi)

     # Create reference particle for which the samples will be retained.
     reference = AdvancedPS.forkr(AdvancedPS.Trace(model, spl, vi, state.rng))
@@ -479,7 +479,7 @@ function AdvancedPS.Trace(
     rng::AdvancedPS.TracedRNG,
 )
     newvarinfo = deepcopy(varinfo)
-    DynamicPPL.reset_num_produce!(newvarinfo)
+    newvarinfo = DynamicPPL.reset_num_produce!!(newvarinfo)

     tmodel = TracedModel(model, sampler, newvarinfo, rng)
     newtrace = AdvancedPS.Trace(tmodel, rng)
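
The reset_num_produce! → reset_num_produce!! migration follows the BangBang convention already used for push!! and empty!! (see the imports in src/mcmc/Inference.jl above): a !!-function mutates its argument when it can, returns a modified copy when it cannot, and the caller must always rebind the result. A runnable illustration with BangBang itself, whose convention DynamicPPL mirrors:

using BangBang

x = [1, 2]
x = push!!(x, 3)  # a Vector is mutable: mutated in place and returned

y = (1, 2)
y = push!!(y, 3)  # a Tuple is immutable: a new tuple (1, 2, 3) is returned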
