Skip to content

Commit 58cef90

Browse files
committed
Progress towards compat with DPPL v0.35
1 parent 08e4c08 commit 58cef90

26 files changed

+83
-309
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25"
6363
DistributionsAD = "0.6"
6464
DocStringExtensions = "0.8, 0.9"
6565
DynamicHMC = "3.4"
66-
DynamicPPL = "0.34.1"
66+
DynamicPPL = "0.35"
6767
EllipticalSliceSampling = "0.5, 1, 2"
6868
ForwardDiff = "0.10.3"
6969
Libtask = "0.8.8"

ext/TuringDynamicHMCExt.jl

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,15 @@ To use it, make sure you have DynamicHMC package (version >= 2) loaded:
2525
using DynamicHMC
2626
```
2727
"""
28-
struct DynamicNUTS{AD,space,T<:DynamicHMC.NUTS} <: Turing.Inference.Hamiltonian
28+
struct DynamicNUTS{AD,T<:DynamicHMC.NUTS} <: Turing.Inference.Hamiltonian
2929
sampler::T
3030
adtype::AD
3131
end
3232

33-
function DynamicNUTS(
34-
spl::DynamicHMC.NUTS=DynamicHMC.NUTS(),
35-
space::Tuple=();
36-
adtype::ADTypes.AbstractADType=Turing.DEFAULT_ADTYPE,
37-
)
38-
return DynamicNUTS{typeof(adtype),space,typeof(spl)}(spl, adtype)
39-
end
33+
DynamicNUTS() = DynamicNUTS(DynamicHMC.NUTS())
34+
DynamicNUTS(spl) = DynamicNUTS(spl, Turing.DEFAULT_ADTYPE)
4035
Turing.externalsampler(spl::DynamicHMC.NUTS) = DynamicNUTS(spl)
4136

42-
DynamicPPL.getspace(::DynamicNUTS{<:Any,space}) where {space} = space
43-
4437
"""
4538
DynamicNUTSState
4639
@@ -70,8 +63,8 @@ function DynamicPPL.initialstep(
7063
kwargs...,
7164
)
7265
# Ensure that initial sample is in unconstrained space.
73-
if !DynamicPPL.islinked(vi, spl)
74-
vi = DynamicPPL.link!!(vi, spl, model)
66+
if !DynamicPPL.islinked(vi)
67+
vi = DynamicPPL.link!!(vi, model)
7568
vi = last(DynamicPPL.evaluate!!(model, vi, DynamicPPL.SamplingContext(rng, spl)))
7669
end
7770

@@ -82,13 +75,13 @@ function DynamicPPL.initialstep(
8275

8376
# Perform initial step.
8477
results = DynamicHMC.mcmc_keep_warmup(
85-
rng, ℓ, 0; initialization=(q=vi[spl],), reporter=DynamicHMC.NoProgressReport()
78+
rng, ℓ, 0; initialization=(q=vi[:],), reporter=DynamicHMC.NoProgressReport()
8679
)
8780
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
8881
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)
8982

9083
# Update the variables.
91-
vi = DynamicPPL.setindex!!(vi, Q.q, spl)
84+
vi = DynamicPPL.unflatten(vi, Q.q)
9285
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
9386

9487
# Create first sample and state.
@@ -112,7 +105,7 @@ function AbstractMCMC.step(
112105
Q, _ = DynamicHMC.mcmc_next_step(steps, state.cache)
113106

114107
# Update the variables.
115-
vi = DynamicPPL.setindex!!(vi, Q.q, spl)
108+
vi = DynamicPPL.unflatten(vi, Q.q)
116109
vi = DynamicPPL.setlogp!!(vi, Q.ℓq)
117110

118111
# Create next sample and state.

src/Turing.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using Compat: pkgversion
99

1010
using AdvancedVI: AdvancedVI
1111
using DynamicPPL: DynamicPPL, LogDensityFunction
12-
import DynamicPPL: getspace, NoDist, NamedDist
12+
import DynamicPPL: NoDist, NamedDist
1313
using LogDensityProblems: LogDensityProblems
1414
using NamedArrays: NamedArrays
1515
using Accessors: Accessors

src/essential/container.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ function AdvancedPS.advance!(
3939
end
4040

4141
function AdvancedPS.delete_retained!(trace::TracedModel)
42-
DynamicPPL.set_retained_vns_del_by_spl!(trace.varinfo, trace.sampler)
42+
DynamicPPL.set_retained_vns_del!(trace.varinfo)
4343
return trace
4444
end
4545

src/mcmc/Inference.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@ using DynamicPPL:
55
Metadata,
66
VarInfo,
77
TypedVarInfo,
8+
# TODO(mhauru) all_varnames_grouped_by_symbol isn't exported by DPPL. Either export it
9+
# or use something else.
10+
all_varnames_grouped_by_symbol,
811
islinked,
912
setindex!!,
1013
push!!,
1114
setlogp!!,
1215
getlogp,
1316
VarName,
1417
getsym,
15-
_getvns,
1618
getdist,
1719
Model,
1820
Sampler,
@@ -22,9 +24,7 @@ using DynamicPPL:
2224
PriorContext,
2325
LikelihoodContext,
2426
set_flag!,
25-
unset_flag!,
26-
getspace,
27-
inspace
27+
unset_flag!
2828
using Distributions, Libtask, Bijectors
2929
using DistributionsAD: VectorOfMultivariate
3030
using LinearAlgebra
@@ -75,9 +75,7 @@ export InferenceAlgorithm,
7575
RepeatSampler,
7676
Prior,
7777
assume,
78-
dot_assume,
7978
observe,
80-
dot_observe,
8179
predict,
8280
externalsampler
8381

@@ -299,7 +297,7 @@ function AbstractMCMC.sample(
299297
kwargs...,
300298
)
301299
check_model && _check_model(model, alg)
302-
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...)
300+
return AbstractMCMC.sample(rng, model, Sampler(alg), N; kwargs...)
303301
end
304302

305303
function AbstractMCMC.sample(
@@ -326,9 +324,7 @@ function AbstractMCMC.sample(
326324
kwargs...,
327325
)
328326
check_model && _check_model(model, alg)
329-
return AbstractMCMC.sample(
330-
rng, model, Sampler(alg, model), ensemble, N, n_chains; kwargs...
331-
)
327+
return AbstractMCMC.sample(rng, model, Sampler(alg), ensemble, N, n_chains; kwargs...)
332328
end
333329

334330
function AbstractMCMC.sample(
@@ -583,11 +579,6 @@ end
583579
# Utilities #
584580
##############
585581

586-
# TODO(mhauru) Remove this once DynamicPPL has removed all its Selector stuff.
587-
DynamicPPL.getspace(::InferenceAlgorithm) = ()
588-
DynamicPPL.getspace(spl::Sampler) = getspace(spl.alg)
589-
DynamicPPL.inspace(vn::VarName, spl::Sampler) = inspace(vn, getspace(spl.alg))
590-
591582
"""
592583
593584
transitions_from_chain(

src/mcmc/emcee.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ function AbstractMCMC.step(
6767
state = EmceeState(
6868
vis[1],
6969
map(vis) do vi
70-
vi = DynamicPPL.link!!(vi, spl, model)
70+
vi = DynamicPPL.link!!(vi, model)
7171
AMH.Transition(vi[spl], getlogp(vi), false)
7272
end,
7373
)
@@ -89,7 +89,7 @@ function AbstractMCMC.step(
8989

9090
# Compute the next transition and state.
9191
transition = map(states) do _state
92-
vi = setindex!!(vi, _state.params, spl)
92+
vi = DynamicPPL.unflatten(vi, _state.params)
9393
t = Transition(getparams(model, vi), _state.lp)
9494
return t
9595
end

src/mcmc/ess.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -136,17 +136,3 @@ function DynamicPPL.tilde_observe(
136136
)
137137
return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
138138
end
139-
140-
function DynamicPPL.dot_tilde_assume(
141-
rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, left, vns, vi
142-
)
143-
return DynamicPPL.dot_tilde_assume(
144-
rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, vi
145-
)
146-
end
147-
148-
function DynamicPPL.dot_tilde_observe(
149-
ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi
150-
)
151-
return DynamicPPL.dot_tilde_observe(ctx, SampleFromPrior(), right, left, vi)
152-
end

src/mcmc/gibbs.jl

Lines changed: 3 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -198,58 +198,6 @@ function DynamicPPL.tilde_assume(
198198
end
199199
end
200200

201-
# Like the above tilde_assume methods, but with dot_tilde_assume and broadcasting of logpdf.
202-
# See comments there for more details.
203-
function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi)
204-
child_context = DynamicPPL.childcontext(context)
205-
return if is_target_varname(context, vns)
206-
DynamicPPL.dot_tilde_assume(child_context, right, left, vns, vi)
207-
elseif has_conditioned_gibbs(context, vns)
208-
value, lp, _ = DynamicPPL.dot_tilde_assume(
209-
child_context, right, left, vns, get_global_varinfo(context)
210-
)
211-
value, lp, vi
212-
else
213-
value, lp, new_global_vi = DynamicPPL.dot_tilde_assume(
214-
child_context,
215-
DynamicPPL.SampleFromPrior(),
216-
right,
217-
left,
218-
vns,
219-
get_global_varinfo(context),
220-
)
221-
set_global_varinfo!(context, new_global_vi)
222-
value, lp, vi
223-
end
224-
end
225-
226-
# As above but with an RNG.
227-
function DynamicPPL.dot_tilde_assume(
228-
rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi
229-
)
230-
child_context = DynamicPPL.childcontext(context)
231-
return if is_target_varname(context, vns)
232-
DynamicPPL.dot_tilde_assume(rng, child_context, sampler, right, left, vns, vi)
233-
elseif has_conditioned_gibbs(context, vns)
234-
value, lp, _ = DynamicPPL.dot_tilde_assume(
235-
child_context, right, left, vns, get_global_varinfo(context)
236-
)
237-
value, lp, vi
238-
else
239-
value, lp, new_global_vi = DynamicPPL.dot_tilde_assume(
240-
rng,
241-
child_context,
242-
DynamicPPL.SampleFromPrior(),
243-
right,
244-
left,
245-
vns,
246-
get_global_varinfo(context),
247-
)
248-
set_global_varinfo!(context, new_global_vi)
249-
value, lp, vi
250-
end
251-
end
252-
253201
"""
254202
make_conditional(model, target_variables, varinfo)
255203
@@ -281,16 +229,8 @@ function make_conditional(
281229
return DynamicPPL.contextualize(model, gibbs_context), gibbs_context_inner
282230
end
283231

284-
# All samplers are given the same Selector, so that they will sample all variables
285-
# given to them by the Gibbs sampler. This avoids conflicts between the new and the old way
286-
# of choosing which sampler to use.
287-
function set_selector(x::DynamicPPL.Sampler)
288-
return DynamicPPL.Sampler(x.alg, DynamicPPL.Selector(0))
289-
end
290-
function set_selector(x::RepeatSampler)
291-
return RepeatSampler(set_selector(x.sampler), x.num_repeat)
292-
end
293-
set_selector(x::InferenceAlgorithm) = DynamicPPL.Sampler(x, DynamicPPL.Selector(0))
232+
wrap_in_sampler(x::AbstractMCMC.AbstractSampler) = x
233+
wrap_in_sampler(x::InferenceAlgorithm) = DynamicPPL.Sampler(x)
294234

295235
to_varname_list(x::Union{VarName,Symbol}) = [VarName(x)]
296236
# Any other value is assumed to be an iterable of VarNames and Symbols.
@@ -343,9 +283,7 @@ struct Gibbs{N,V<:NTuple{N,AbstractVector{<:VarName}},A<:NTuple{N,Any}} <:
343283
end
344284
end
345285

346-
# Ensure that samplers have the same selector, and that varnames are lists of
347-
# VarNames.
348-
samplers = tuple(map(set_selector, samplers)...)
286+
samplers = tuple(map(wrap_in_sampler, samplers)...)
349287
varnames = tuple(map(to_varname_list, varnames)...)
350288
return new{length(samplers),typeof(varnames),typeof(samplers)}(varnames, samplers)
351289
end

src/mcmc/hmc.jl

Lines changed: 8 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,10 @@ function DynamicPPL.initialstep(
148148
kwargs...,
149149
)
150150
# Transform the samples to unconstrained space and compute the joint log probability.
151-
vi = DynamicPPL.link(vi_original, spl, model)
151+
vi = DynamicPPL.link(vi_original, model)
152152

153153
# Extract parameters.
154-
theta = vi[spl]
154+
theta = vi[:]
155155

156156
# Create a Hamiltonian.
157157
metricT = getmetricT(spl.alg)
@@ -189,7 +189,7 @@ function DynamicPPL.initialstep(
189189

190190
# NOTE: This will sample in the unconstrained space.
191191
vi = last(DynamicPPL.evaluate!!(model, rng, vi, SampleFromUniform()))
192-
theta = vi[spl]
192+
theta = vi[:]
193193

194194
hamiltonian = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)
195195
z = AHMC.phasepoint(rng, theta, hamiltonian)
@@ -226,10 +226,10 @@ function DynamicPPL.initialstep(
226226

227227
# Update `vi` based on acceptance
228228
if t.stat.is_accept
229-
vi = DynamicPPL.unflatten(vi, spl, t.z.θ)
229+
vi = DynamicPPL.unflatten(vi, t.z.θ)
230230
vi = setlogp!!(vi, t.stat.log_density)
231231
else
232-
vi = DynamicPPL.unflatten(vi, spl, theta)
232+
vi = DynamicPPL.unflatten(vi, theta)
233233
vi = setlogp!!(vi, log_density_old)
234234
end
235235

@@ -274,7 +274,7 @@ function AbstractMCMC.step(
274274
# Update variables
275275
vi = state.vi
276276
if t.stat.is_accept
277-
vi = DynamicPPL.unflatten(vi, spl, t.z.θ)
277+
vi = DynamicPPL.unflatten(vi, t.z.θ)
278278
vi = setlogp!!(vi, t.stat.log_density)
279279
end
280280

@@ -493,45 +493,15 @@ end
493493
#### Compiler interface, i.e. tilde operators.
494494
####
495495
function DynamicPPL.assume(
496-
rng, spl::Sampler{<:Hamiltonian}, dist::Distribution, vn::VarName, vi
496+
rng, ::Sampler{<:Hamiltonian}, dist::Distribution, vn::VarName, vi
497497
)
498498
return DynamicPPL.assume(dist, vn, vi)
499499
end
500500

501-
function DynamicPPL.dot_assume(
502-
rng,
503-
spl::Sampler{<:Hamiltonian},
504-
dist::MultivariateDistribution,
505-
vns::AbstractArray{<:VarName},
506-
var::AbstractMatrix,
507-
vi,
508-
)
509-
return DynamicPPL.dot_assume(dist, var, vns, vi)
510-
end
511-
function DynamicPPL.dot_assume(
512-
rng,
513-
spl::Sampler{<:Hamiltonian},
514-
dists::Union{Distribution,AbstractArray{<:Distribution}},
515-
vns::AbstractArray{<:VarName},
516-
var::AbstractArray,
517-
vi,
518-
)
519-
return DynamicPPL.dot_assume(dists, var, vns, vi)
520-
end
521-
522-
function DynamicPPL.observe(spl::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
501+
function DynamicPPL.observe(::Sampler{<:Hamiltonian}, d::Distribution, value, vi)
523502
return DynamicPPL.observe(d, value, vi)
524503
end
525504

526-
function DynamicPPL.dot_observe(
527-
spl::Sampler{<:Hamiltonian},
528-
ds::Union{Distribution,AbstractArray{<:Distribution}},
529-
value::AbstractArray,
530-
vi,
531-
)
532-
return DynamicPPL.dot_observe(ds, value, vi)
533-
end
534-
535505
####
536506
#### Default HMC stepsize and mass matrix adaptor
537507
####

0 commit comments

Comments
 (0)