Skip to content

Commit 6f11379

Browse files
committed
Use FastLDF and removal of NodeTrait
1 parent 4153a83 commit 6f11379

File tree

10 files changed

+39
-30
lines changed

10 files changed

+39
-30
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
6464
DistributionsAD = "0.6"
6565
DocStringExtensions = "0.8, 0.9"
6666
DynamicHMC = "3.4"
67-
DynamicPPL = "0.38"
67+
DynamicPPL = "0.39"
6868
EllipticalSliceSampling = "0.5, 1, 2"
6969
ForwardDiff = "0.10.3, 1"
7070
Libtask = "0.9.3"
@@ -90,3 +90,6 @@ julia = "1.10.8"
9090
[extras]
9191
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
9292
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
93+
94+
[sources]
95+
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/ldf"}

src/mcmc/gibbs.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,9 @@ isgibbscomponent(::AdvancedHMC.AbstractHMCSampler) = true
2121
isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
2222
isgibbscomponent(spl) = false
2323

24-
function can_be_wrapped(ctx::DynamicPPL.AbstractContext)
25-
return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf
26-
end
27-
can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
24+
can_be_wrapped(::DynamicPPL.AbstractContext) = true
25+
can_be_wrapped(::DynamicPPL.AbstractParentContext) = false
26+
can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(DynamicPPL.childcontext(ctx))
2827

2928
# Basically like a `DynamicPPL.FixedContext` but
3029
# 1. Hijacks the tilde pipeline to fix variables.
@@ -55,7 +54,7 @@ $(FIELDS)
5554
"""
5655
struct GibbsContext{
5756
VNs<:Tuple{Vararg{VarName}},GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext
58-
} <: DynamicPPL.AbstractContext
57+
} <: DynamicPPL.AbstractParentContext
5958
"""
6059
the VarNames being sampled
6160
"""
@@ -86,7 +85,6 @@ function GibbsContext(target_varnames, global_varinfo)
8685
return GibbsContext(target_varnames, global_varinfo, DynamicPPL.DefaultContext())
8786
end
8887

89-
DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent()
9088
DynamicPPL.childcontext(context::GibbsContext) = context.context
9189
function DynamicPPL.setchildcontext(context::GibbsContext, childcontext)
9290
return GibbsContext(

src/mcmc/hmc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ function Turing.Inference.initialstep(
196196
# Create a Hamiltonian.
197197
metricT = getmetricT(spl)
198198
metric = metricT(length(theta))
199-
ldf = DynamicPPL.LogDensityFunction(
199+
ldf = DynamicPPL.Experimental.FastLDF(
200200
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
201201
)
202202
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
@@ -278,7 +278,7 @@ end
278278

279279
function get_hamiltonian(model, spl, vi, state, n)
280280
metric = gen_metric(n, spl, state)
281-
ldf = DynamicPPL.LogDensityFunction(
281+
ldf = DynamicPPL.Experimental.FastLDF(
282282
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
283283
)
284284
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)

src/mcmc/is.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ end
4949
struct ISContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
5050
rng::R
5151
end
52-
DynamicPPL.NodeTrait(::ISContext) = DynamicPPL.IsLeaf()
5352

5453
function DynamicPPL.tilde_assume!!(
5554
ctx::ISContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo

src/mcmc/mh.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,6 @@ end
410410
struct MHContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
411411
rng::R
412412
end
413-
DynamicPPL.NodeTrait(::MHContext) = DynamicPPL.IsLeaf()
414413

415414
function DynamicPPL.tilde_assume!!(
416415
context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo

src/mcmc/particle_mcmc.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
88
rng::R
99
end
10-
DynamicPPL.NodeTrait(::ParticleMCMCContext) = DynamicPPL.IsLeaf()
1110

1211
struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel
1312
model::M

src/mcmc/prior.jl

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,30 @@ Algorithm for sampling from the prior.
55
"""
66
struct Prior <: AbstractSampler end
77

8+
function AbstractMCMC.step(
9+
rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::Prior; kwargs...
10+
)
11+
accs = DynamicPPL.AccumulatorTuple((
12+
DynamicPPL.ValuesAsInModelAccumulator(true),
13+
DynamicPPL.LogPriorAccumulator(),
14+
DynamicPPL.LogLikelihoodAccumulator(),
15+
))
16+
sampling_model = DynamicPPL.setleafcontext(
17+
model, DynamicPPL.InitContext(rng, DynamicPPL.InitFromPrior())
18+
)
19+
vi = DynamicPPL.OnlyAccsVarInfo(accs)
20+
_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
21+
return Transition(sampling_model, vi, nothing; reevaluate=false), (sampling_model, vi)
22+
end
23+
824
function AbstractMCMC.step(
925
rng::Random.AbstractRNG,
1026
model::DynamicPPL.Model,
1127
sampler::Prior,
12-
state=nothing;
28+
state::Tuple{DynamicPPL.Model,DynamicPPL.Experimental.OnlyAccsVarInfo};
1329
kwargs...,
1430
)
15-
vi = DynamicPPL.setaccs!!(
16-
DynamicPPL.VarInfo(),
17-
(
18-
DynamicPPL.ValuesAsInModelAccumulator(true),
19-
DynamicPPL.LogPriorAccumulator(),
20-
DynamicPPL.LogLikelihoodAccumulator(),
21-
),
22-
)
23-
_, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior())
24-
return Transition(model, vi, nothing; reevaluate=false), nothing
31+
model, vi = state
32+
_, vi = DynamicPPL.evaluate!!(model, vi)
33+
return Transition(model, vi, nothing; reevaluate=false), (model, vi)
2534
end

test/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Combinatorics = "1"
5353
Distributions = "0.25"
5454
DistributionsAD = "0.6.3"
5555
DynamicHMC = "2.1.6, 3.0"
56-
DynamicPPL = "0.38"
56+
DynamicPPL = "0.39"
5757
FiniteDifferences = "0.10.8, 0.11, 0.12"
5858
ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
5959
HypothesisTests = "0.11"
@@ -77,3 +77,6 @@ StatsBase = "0.33, 0.34"
7777
StatsFuns = "0.9.5, 1"
7878
TimerOutputs = "0.5"
7979
julia = "1.10"
80+
81+
[sources]
82+
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/ldf"}

test/ad.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ encountered.
9494
9595
"""
9696
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
97-
DynamicPPL.AbstractContext
97+
DynamicPPL.AbstractParentContext
9898
child::ChildContext
9999

100100
function ADTypeCheckContext(adbackend, child)
@@ -108,7 +108,6 @@ end
108108

109109
adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType
110110

111-
DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent()
112111
DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child
113112
function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child)
114113
return ADTypeCheckContext(adtype(c), child)

test/mcmc/gibbs.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,15 @@ end
159159
# It is modified by the capture_targets_and_algs function.
160160
targets_and_algs = Any[]
161161

162-
function capture_targets_and_algs(sampler, context)
163-
if DynamicPPL.NodeTrait(context) == DynamicPPL.IsLeaf()
164-
return nothing
165-
end
162+
function capture_targets_and_algs(sampler, context::DynamicPPL.AbstractParentContext)
166163
if context isa Inference.GibbsContext
167164
push!(targets_and_algs, (context.target_varnames, sampler))
168165
end
169166
return capture_targets_and_algs(sampler, DynamicPPL.childcontext(context))
170167
end
168+
function capture_targets_and_algs(sampler, ::DynamicPPL.AbstractContext)
169+
return nothing # Leaf context.
170+
end
171171

172172
# The methods that capture testing information for us.
173173
function AbstractMCMC.step(

0 commit comments

Comments
 (0)