diff --git a/Project.toml b/Project.toml
index cb7b1cb72..aed8d8244 100644
--- a/Project.toml
+++ b/Project.toml
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
 DynamicHMC = "3.4"
-DynamicPPL = "0.38"
+DynamicPPL = "0.39"
 EllipticalSliceSampling = "0.5, 1, 2"
 ForwardDiff = "0.10.3, 1"
 Libtask = "0.9.3"
@@ -90,3 +90,6 @@ julia = "1.10.8"
 [extras]
 DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
+
+[sources]
+DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/ldf"}
diff --git a/src/mcmc/gibbs.jl b/src/mcmc/gibbs.jl
index 7d15829a3..a5e3c0b89 100644
--- a/src/mcmc/gibbs.jl
+++ b/src/mcmc/gibbs.jl
@@ -21,10 +21,9 @@ isgibbscomponent(::AdvancedHMC.AbstractHMCSampler) = true
 isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
 isgibbscomponent(spl) = false
 
-function can_be_wrapped(ctx::DynamicPPL.AbstractContext)
-    return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf
-end
-can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
+can_be_wrapped(::DynamicPPL.AbstractContext) = true
+can_be_wrapped(::DynamicPPL.AbstractParentContext) = false
+can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(DynamicPPL.childcontext(ctx))
 
 # Basically like a `DynamicPPL.FixedContext` but
 # 1. Hijacks the tilde pipeline to fix variables.
@@ -55,7 +54,7 @@ $(FIELDS)
 """
 struct GibbsContext{
     VNs<:Tuple{Vararg{VarName}},GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext
-} <: DynamicPPL.AbstractContext
+} <: DynamicPPL.AbstractParentContext
     """
     the VarNames being sampled
     """
@@ -86,7 +85,6 @@ function GibbsContext(target_varnames, global_varinfo)
     return GibbsContext(target_varnames, global_varinfo, DynamicPPL.DefaultContext())
 end
 
-DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent()
 DynamicPPL.childcontext(context::GibbsContext) = context.context
 function DynamicPPL.setchildcontext(context::GibbsContext, childcontext)
     return GibbsContext(
diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl
index 101847b75..bf566b485 100644
--- a/src/mcmc/hmc.jl
+++ b/src/mcmc/hmc.jl
@@ -12,6 +12,7 @@ struct HMCState{
     THam<:AHMC.Hamiltonian,
     PhType<:AHMC.PhasePoint,
     TAdapt<:AHMC.Adaptation.AbstractAdaptor,
+    L<:DynamicPPL.Experimental.FastLDF,
 }
     vi::TV
     i::Int
@@ -19,6 +20,7 @@ struct HMCState{
     hamiltonian::THam
     z::PhType
     adaptor::TAdapt
+    ldf::L
 end
 
 ###
@@ -196,7 +198,7 @@ function Turing.Inference.initialstep(
     # Create a Hamiltonian.
     metricT = getmetricT(spl)
     metric = metricT(length(theta))
-    ldf = DynamicPPL.LogDensityFunction(
+    ldf = DynamicPPL.Experimental.FastLDF(
         model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
     )
     lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
@@ -225,8 +227,8 @@ function Turing.Inference.initialstep(
     kernel = make_ahmc_kernel(spl, ϵ)
     adaptor = AHMCAdaptor(spl, hamiltonian.metric; ϵ=ϵ)
 
-    transition = Transition(model, vi, NamedTuple())
-    state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor)
+    transition = DynamicPPL.ParamsWithStats(theta, ldf, NamedTuple())
+    state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor, ldf)
 
     return transition, state
 end
@@ -270,15 +272,15 @@ function AbstractMCMC.step(
     end
 
     # Compute next transition and state.
-    transition = Transition(model, vi, t)
-    newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
+    transition = DynamicPPL.ParamsWithStats(t.z.θ, state.ldf, t.stat)
+    newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, state.ldf)
 
     return transition, newstate
 end
 
 function get_hamiltonian(model, spl, vi, state, n)
     metric = gen_metric(n, spl, state)
-    ldf = DynamicPPL.LogDensityFunction(
+    ldf = DynamicPPL.Experimental.FastLDF(
         model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
     )
     lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
diff --git a/src/mcmc/is.jl b/src/mcmc/is.jl
index 88f915d1f..9f6918046 100644
--- a/src/mcmc/is.jl
+++ b/src/mcmc/is.jl
@@ -49,7 +49,6 @@ end
 struct ISContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
     rng::R
 end
-DynamicPPL.NodeTrait(::ISContext) = DynamicPPL.IsLeaf()
 
 function DynamicPPL.tilde_assume!!(
     ctx::ISContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
diff --git a/src/mcmc/mh.jl b/src/mcmc/mh.jl
index 833303b86..0e7b5d2f7 100644
--- a/src/mcmc/mh.jl
+++ b/src/mcmc/mh.jl
@@ -410,7 +410,6 @@ end
 struct MHContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
     rng::R
 end
-DynamicPPL.NodeTrait(::MHContext) = DynamicPPL.IsLeaf()
 
 function DynamicPPL.tilde_assume!!(
     context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl
index 7aadef09e..387e2c529 100644
--- a/src/mcmc/particle_mcmc.jl
+++ b/src/mcmc/particle_mcmc.jl
@@ -7,7 +7,6 @@
 struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
     rng::R
 end
-DynamicPPL.NodeTrait(::ParticleMCMCContext) = DynamicPPL.IsLeaf()
 
 struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel
     model::M
diff --git a/src/mcmc/prior.jl b/src/mcmc/prior.jl
index c4ec6c6f3..47fb75e07 100644
--- a/src/mcmc/prior.jl
+++ b/src/mcmc/prior.jl
@@ -5,21 +5,30 @@ Algorithm for sampling from the prior.
 """
 struct Prior <: AbstractSampler end
 
+function AbstractMCMC.step(
+    rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::Prior; kwargs...
+)
+    accs = DynamicPPL.AccumulatorTuple((
+        DynamicPPL.ValuesAsInModelAccumulator(true),
+        DynamicPPL.LogPriorAccumulator(),
+        DynamicPPL.LogLikelihoodAccumulator(),
+    ))
+    sampling_model = DynamicPPL.setleafcontext(
+        model, DynamicPPL.InitContext(rng, DynamicPPL.InitFromPrior())
+    )
+    vi = DynamicPPL.OnlyAccsVarInfo(accs)
+    _, vi = DynamicPPL.evaluate!!(sampling_model, vi)
+    return Transition(sampling_model, vi, nothing; reevaluate=false), (sampling_model, vi)
+end
+
 function AbstractMCMC.step(
     rng::Random.AbstractRNG,
     model::DynamicPPL.Model,
     sampler::Prior,
-    state=nothing;
+    state::Tuple{DynamicPPL.Model,DynamicPPL.Experimental.OnlyAccsVarInfo};
     kwargs...,
 )
-    vi = DynamicPPL.setaccs!!(
-        DynamicPPL.VarInfo(),
-        (
-            DynamicPPL.ValuesAsInModelAccumulator(true),
-            DynamicPPL.LogPriorAccumulator(),
-            DynamicPPL.LogLikelihoodAccumulator(),
-        ),
-    )
-    _, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior())
-    return Transition(model, vi, nothing; reevaluate=false), nothing
+    model, vi = state
+    _, vi = DynamicPPL.evaluate!!(model, vi)
+    return Transition(model, vi, nothing; reevaluate=false), (model, vi)
 end
diff --git a/test/Project.toml b/test/Project.toml
index 2b5b124b5..029fa1463 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -53,7 +53,7 @@ Combinatorics = "1"
 Distributions = "0.25"
 DistributionsAD = "0.6.3"
 DynamicHMC = "2.1.6, 3.0"
-DynamicPPL = "0.38"
+DynamicPPL = "0.39"
 FiniteDifferences = "0.10.8, 0.11, 0.12"
 ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
 HypothesisTests = "0.11"
@@ -77,3 +77,6 @@ StatsBase = "0.33, 0.34"
 StatsFuns = "0.9.5, 1"
 TimerOutputs = "0.5"
 julia = "1.10"
+
+[sources]
+DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/ldf"}
diff --git a/test/ad.jl b/test/ad.jl
index 287c92834..1fa8003fd 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -94,7 +94,7 @@
 encountered.
 """
 struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
-       DynamicPPL.AbstractContext
+       DynamicPPL.AbstractParentContext
     child::ChildContext
 
     function ADTypeCheckContext(adbackend, child)
@@ -108,7 +108,6 @@ end
 
 adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType
 
-DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent()
 DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child
 function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child)
     return ADTypeCheckContext(adtype(c), child)
@@ -138,14 +137,25 @@ Check that the element types in `vi` are compatible with the ADType of `context
 Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
 """
 function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
+    # If we are using InitFromPrior or InitFromUniform to generate new values,
+    # then the parameter type will be Any, so we should skip the check.
+    lc = DynamicPPL.leafcontext(context)
+    if lc isa DynamicPPL.InitContext{
+        <:Any,<:Union{DynamicPPL.InitFromPrior,DynamicPPL.InitFromUniform}
+    }
+        return nothing
+    end
+    # Having ruled those out, the leaf context is either DefaultContext (i.e., the
+    # old, slow LogDensityFunction) or InitFromParams{<:VectorWithRanges} (i.e., the
+    # new, fast LogDensityFunction). You would never really want to run AD with the
+    # init-from-prior/uniform strategies anyway, and for the two remaining cases
+    # `get_param_eltype` returns a sensible parameter element type (not `Any`), so
+    # we can check it directly against the eltypes that are valid for this backend.
+    param_eltype = DynamicPPL.get_param_eltype(vi, context)
     valids = valid_eltypes(context)
-    for val in vi[:]
-        valtype = typeof(val)
-        if !any(valtype .<: valids)
-            throw(IncompatibleADTypeError(valtype, adtype(context)))
-        end
+    if !(any(param_eltype .<: valids))
+        throw(IncompatibleADTypeError(param_eltype, adtype(context)))
     end
-    return nothing
 end
 
 # A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
@@ -200,10 +210,10 @@ end
         @testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
             if actual_adtype == expected_adtype
                 # Check that this does not throw an error.
-                sample(contextualised_tm, sampler, 2)
+                sample(contextualised_tm, sampler, 2; check_model=false)
             else
                 @test_throws AbstractWrongADBackendError sample(
-                    contextualised_tm, sampler, 2
+                    contextualised_tm, sampler, 2; check_model=false
                 )
             end
         end
diff --git a/test/mcmc/gibbs.jl b/test/mcmc/gibbs.jl
index 1e3d5856c..af2995e44 100644
--- a/test/mcmc/gibbs.jl
+++ b/test/mcmc/gibbs.jl
@@ -159,15 +159,15 @@ end
     # It is modified by the capture_targets_and_algs function.
     targets_and_algs = Any[]
 
-    function capture_targets_and_algs(sampler, context)
-        if DynamicPPL.NodeTrait(context) == DynamicPPL.IsLeaf()
-            return nothing
-        end
+    function capture_targets_and_algs(sampler, context::DynamicPPL.AbstractParentContext)
         if context isa Inference.GibbsContext
             push!(targets_and_algs, (context.target_varnames, sampler))
         end
         return capture_targets_and_algs(sampler, DynamicPPL.childcontext(context))
     end
+    function capture_targets_and_algs(sampler, ::DynamicPPL.AbstractContext)
+        return nothing # Leaf context.
+    end
 
     # The methods that capture testing information for us.
     function AbstractMCMC.step(
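
A few usage sketches for the APIs this patch touches follow. They are illustrations only, not part of the diff, and they assume the `py/ldf` DynamicPPL branch pinned under [sources] above.

On the NodeTrait removals (src/mcmc/gibbs.jl, is.jl, mh.jl, particle_mcmc.jl, test/ad.jl): leaf versus parent is now encoded in the context type hierarchy rather than queried via `DynamicPPL.NodeTrait`, so plain dispatch replaces the old trait checks. A minimal sketch of the pattern, assuming `AbstractParentContext <: AbstractContext` as the new `can_be_wrapped` methods imply; `leafname` is a made-up helper (DynamicPPL's own `leafcontext` does the same walk):

using DynamicPPL

# Contexts that wrap a child subtype AbstractParentContext; everything else is a leaf,
# so the two methods below cover the whole hierarchy without a trait function.
leafname(ctx::DynamicPPL.AbstractContext) = nameof(typeof(ctx))
leafname(ctx::DynamicPPL.AbstractParentContext) = leafname(DynamicPPL.childcontext(ctx))

leafname(DynamicPPL.DefaultContext())  # :DefaultContext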
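
On the src/mcmc/hmc.jl hunks: `DynamicPPL.Experimental.FastLDF` is constructed with the same arguments as the `DynamicPPL.LogDensityFunction` it replaces, and the surrounding code keeps consuming it through the LogDensityProblems interface (the `Base.Fix1(LogDensityProblems.logdensity, ldf)` lines are unchanged). A minimal sketch under that assumption; `demo` is a throwaway model, and `logdensity_and_gradient` is assumed to use the supplied `adtype`:

using Turing, DynamicPPL, LogDensityProblems

@model function demo(y)
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo(1.0)
vi = DynamicPPL.VarInfo(model)

# Same constructor arguments as the call sites changed above.
ldf = DynamicPPL.Experimental.FastLDF(
    model, DynamicPPL.getlogjoint_internal, vi; adtype=AutoForwardDiff()
)

LogDensityProblems.logdensity(ldf, [0.5])               # log joint at x = 0.5
LogDensityProblems.logdensity_and_gradient(ldf, [0.5])  # gradient via the adtype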
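
On src/mcmc/prior.jl: the user-facing API is unchanged. The first step now builds the InitContext-wrapped model and an accumulator-only varinfo once, and every later step simply re-evaluates that pair, which is carried along as the sampler state `(model, vi)`. Calling code stays as before:

using Turing, Random

@model function coin(y)
    p ~ Beta(2, 2)
    y .~ Bernoulli(p)
end

# Each draw is an independent sample from the prior of `p`.
chain = sample(Xoshiro(42), coin([true, false, true]), Prior(), 200)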
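
On test/ad.jl: `check_adtype` now asks the evaluation context for a single parameter eltype via `get_param_eltype` instead of looping over `vi[:]`. For reference, the harness wraps a model roughly like this (a sketch: `tm` stands for whichever test model is being wrapped, and `ADTypeCheckContext` is test-only scaffolding from test/ad.jl, not Turing API):

# Route every tilde statement through the AD-type check.
contextualised_tm = DynamicPPL.contextualize(
    tm, ADTypeCheckContext(AutoForwardDiff(), tm.context)
)
# The tests above now pass check_model=false, presumably so that the extra
# model-checking evaluation does not hit the eltype check with values it is
# not meant to see.
sample(contextualised_tm, HMC(0.1, 5; adtype=AutoForwardDiff()), 2; check_model=false)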