5 changes: 4 additions & 1 deletion Project.toml
@@ -64,7 +64,7 @@ Distributions = "0.25.77"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.38"
DynamicPPL = "0.39"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.9.3"
@@ -90,3 +90,6 @@ julia = "1.10.8"
[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"

[sources]
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/ldf"}
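
A note on the new `[sources]` table: it points Pkg at the `py/ldf` branch of the DynamicPPL repository instead of the registered release, presumably so the environment resolves while the `DynamicPPL = "0.39"` compat bound above refers to an unreleased version. For local development against the same branch, the manual equivalent would be roughly:

```julia
using Pkg

# Pin DynamicPPL to the same branch the [sources] entry points at.
# (Illustrative only; drop this once DynamicPPL 0.39 is registered.)
Pkg.add(url="https://github.com/TuringLang/DynamicPPL.jl", rev="py/ldf")
```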
10 changes: 4 additions & 6 deletions src/mcmc/gibbs.jl
@@ -21,10 +21,9 @@ isgibbscomponent(::AdvancedHMC.AbstractHMCSampler) = true
isgibbscomponent(::AdvancedMH.MetropolisHastings) = true
isgibbscomponent(spl) = false

function can_be_wrapped(ctx::DynamicPPL.AbstractContext)
return DynamicPPL.NodeTrait(ctx) isa DynamicPPL.IsLeaf
end
can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(ctx.context)
can_be_wrapped(::DynamicPPL.AbstractContext) = true
can_be_wrapped(::DynamicPPL.AbstractParentContext) = false
can_be_wrapped(ctx::DynamicPPL.PrefixContext) = can_be_wrapped(DynamicPPL.childcontext(ctx))

# Basically like a `DynamicPPL.FixedContext` but
# 1. Hijacks the tilde pipeline to fix variables.
@@ -55,7 +54,7 @@ $(FIELDS)
"""
struct GibbsContext{
VNs<:Tuple{Vararg{VarName}},GVI<:Ref{<:AbstractVarInfo},Ctx<:DynamicPPL.AbstractContext
} <: DynamicPPL.AbstractContext
} <: DynamicPPL.AbstractParentContext
"""
the VarNames being sampled
"""
@@ -86,7 +85,6 @@ function GibbsContext(target_varnames, global_varinfo)
return GibbsContext(target_varnames, global_varinfo, DynamicPPL.DefaultContext())
end

DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::GibbsContext) = context.context
function DynamicPPL.setchildcontext(context::GibbsContext, childcontext)
return GibbsContext(
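
The change above (together with the one-line `NodeTrait` deletions in `is.jl`, `mh.jl`, and `particle_mcmc.jl` below) swaps DynamicPPL's `NodeTrait` trait for dispatch on an abstract supertype: leaf contexts remain plain `AbstractContext` subtypes, while wrapping contexts subtype `AbstractParentContext`. A standalone sketch of that pattern, using hypothetical stand-in types rather than DynamicPPL's:

```julia
# Sketch of the dispatch pattern adopted here: instead of querying NodeTrait(ctx),
# leaf and parent contexts are told apart by their abstract supertype.
# All types below are hypothetical stand-ins, not DynamicPPL's.
abstract type AbstractContext end
abstract type AbstractParentContext <: AbstractContext end

struct LeafCtx <: AbstractContext end             # a leaf: wraps nothing
struct PrefixCtx{C<:AbstractContext} <: AbstractParentContext
    child::C                                      # a parent: wraps a child context
end
struct OtherParentCtx{C<:AbstractContext} <: AbstractParentContext
    child::C
end
childcontext(ctx::PrefixCtx) = ctx.child

# Leaves can be wrapped; parents cannot, except that a prefix wrapper defers the
# decision to whatever it wraps (mirroring the PrefixContext method in the diff).
can_be_wrapped(::AbstractContext) = true
can_be_wrapped(::AbstractParentContext) = false
can_be_wrapped(ctx::PrefixCtx) = can_be_wrapped(childcontext(ctx))

@assert can_be_wrapped(LeafCtx())
@assert can_be_wrapped(PrefixCtx(LeafCtx()))
@assert !can_be_wrapped(OtherParentCtx(LeafCtx()))
```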
14 changes: 8 additions & 6 deletions src/mcmc/hmc.jl
@@ -12,13 +12,15 @@ struct HMCState{
THam<:AHMC.Hamiltonian,
PhType<:AHMC.PhasePoint,
TAdapt<:AHMC.Adaptation.AbstractAdaptor,
L<:DynamicPPL.Experimental.FastLDF,
}
vi::TV
i::Int
kernel::TKernel
hamiltonian::THam
z::PhType
adaptor::TAdapt
ldf::L
end

###
@@ -196,7 +198,7 @@ function Turing.Inference.initialstep(
# Create a Hamiltonian.
metricT = getmetricT(spl)
metric = metricT(length(theta))
ldf = DynamicPPL.LogDensityFunction(
ldf = DynamicPPL.Experimental.FastLDF(
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
)
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
@@ -225,8 +227,8 @@
kernel = make_ahmc_kernel(spl, ϵ)
adaptor = AHMCAdaptor(spl, hamiltonian.metric; ϵ=ϵ)

transition = Transition(model, vi, NamedTuple())
state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor)
transition = DynamicPPL.ParamsWithStats(theta, ldf, NamedTuple())
state = HMCState(vi, 1, kernel, hamiltonian, z, adaptor, ldf)

return transition, state
end
@@ -270,15 +272,15 @@ function AbstractMCMC.step(
end

# Compute next transition and state.
transition = Transition(model, vi, t)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor)
transition = DynamicPPL.ParamsWithStats(t.z.θ, state.ldf, t.stat)
newstate = HMCState(vi, i, kernel, hamiltonian, t.z, state.adaptor, state.ldf)

return transition, newstate
end

function get_hamiltonian(model, spl, vi, state, n)
metric = gen_metric(n, spl, state)
ldf = DynamicPPL.LogDensityFunction(
ldf = DynamicPPL.Experimental.FastLDF(
model, DynamicPPL.getlogjoint_internal, vi; adtype=spl.adtype
)
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
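
Both the old `DynamicPPL.LogDensityFunction` and the branch's `DynamicPPL.Experimental.FastLDF` are consumed here only through the LogDensityProblems.jl interface, which is why the rest of `initialstep`/`step` is unchanged apart from threading `ldf` through `HMCState`. A minimal sketch of that interface with a toy log density (not DynamicPPL's implementation):

```julia
using LogDensityProblems

# A toy log-density object implementing the same interface the sampler relies on.
struct ToyLDF end
LogDensityProblems.logdensity(::ToyLDF, θ) = -sum(abs2, θ) / 2   # standard normal, up to a constant
LogDensityProblems.dimension(::ToyLDF) = 2
LogDensityProblems.capabilities(::Type{ToyLDF}) = LogDensityProblems.LogDensityOrder{0}()

ldf = ToyLDF()
# Same idiom as in the hunks above: a closure θ -> logdensity(ldf, θ).
lp_func = Base.Fix1(LogDensityProblems.logdensity, ldf)
lp_func([0.1, -0.3])
```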
1 change: 0 additions & 1 deletion src/mcmc/is.jl
@@ -49,7 +49,6 @@ end
struct ISContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
rng::R
end
DynamicPPL.NodeTrait(::ISContext) = DynamicPPL.IsLeaf()

function DynamicPPL.tilde_assume!!(
ctx::ISContext, dist::Distribution, vn::VarName, vi::AbstractVarInfo
1 change: 0 additions & 1 deletion src/mcmc/mh.jl
@@ -410,7 +410,6 @@ end
struct MHContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
rng::R
end
DynamicPPL.NodeTrait(::MHContext) = DynamicPPL.IsLeaf()

function DynamicPPL.tilde_assume!!(
context::MHContext, right::Distribution, vn::VarName, vi::AbstractVarInfo
1 change: 0 additions & 1 deletion src/mcmc/particle_mcmc.jl
@@ -7,7 +7,6 @@
struct ParticleMCMCContext{R<:AbstractRNG} <: DynamicPPL.AbstractContext
rng::R
end
DynamicPPL.NodeTrait(::ParticleMCMCContext) = DynamicPPL.IsLeaf()

struct TracedModel{V<:AbstractVarInfo,M<:Model,E<:Tuple} <: AdvancedPS.AbstractGenericModel
model::M
31 changes: 20 additions & 11 deletions src/mcmc/prior.jl
@@ -5,21 +5,30 @@ Algorithm for sampling from the prior.
"""
struct Prior <: AbstractSampler end

function AbstractMCMC.step(
rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::Prior; kwargs...
)
accs = DynamicPPL.AccumulatorTuple((
DynamicPPL.ValuesAsInModelAccumulator(true),
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
))
sampling_model = DynamicPPL.setleafcontext(
model, DynamicPPL.InitContext(rng, DynamicPPL.InitFromPrior())
)
vi = DynamicPPL.OnlyAccsVarInfo(accs)
_, vi = DynamicPPL.evaluate!!(sampling_model, vi)
return Transition(sampling_model, vi, nothing; reevaluate=false), (sampling_model, vi)
end

function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler::Prior,
state=nothing;
state::Tuple{DynamicPPL.Model,DynamicPPL.Experimental.OnlyAccsVarInfo};
kwargs...,
)
vi = DynamicPPL.setaccs!!(
DynamicPPL.VarInfo(),
(
DynamicPPL.ValuesAsInModelAccumulator(true),
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
),
)
_, vi = DynamicPPL.init!!(model, vi, DynamicPPL.InitFromPrior())
return Transition(model, vi, nothing; reevaluate=false), nothing
model, vi = state
_, vi = DynamicPPL.evaluate!!(model, vi)
return Transition(model, vi, nothing; reevaluate=false), (model, vi)
end
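
The rewritten `Prior` sampler follows AbstractMCMC's stepping protocol: one `step` method without a state argument performs the first iteration and builds the state, and a second method consumes (and returns) that state on every later iteration. A toy illustration of the protocol, with placeholder model and sampler types rather than Turing's:

```julia
using AbstractMCMC, Random

struct ToyModel <: AbstractMCMC.AbstractModel end
struct ToySampler <: AbstractMCMC.AbstractSampler end

# First iteration: no state argument; construct the state from scratch.
function AbstractMCMC.step(rng::Random.AbstractRNG, ::ToyModel, ::ToySampler; kwargs...)
    x = randn(rng)
    return x, x                       # (sample, state)
end

# Later iterations: reuse and update the state returned by the previous call.
function AbstractMCMC.step(
    rng::Random.AbstractRNG, ::ToyModel, ::ToySampler, state::Float64; kwargs...
)
    x = state + randn(rng)
    return x, x
end

rng = Random.default_rng()
draw, state = AbstractMCMC.step(rng, ToyModel(), ToySampler())
draw, state = AbstractMCMC.step(rng, ToyModel(), ToySampler(), state)
```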
5 changes: 4 additions & 1 deletion test/Project.toml
@@ -53,7 +53,7 @@ Combinatorics = "1"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.38"
DynamicPPL = "0.39"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
HypothesisTests = "0.11"
@@ -77,3 +77,6 @@ StatsBase = "0.33, 0.34"
StatsFuns = "0.9.5, 1"
TimerOutputs = "0.5"
julia = "1.10"

[sources]
DynamicPPL = {url = "https://github.com/TuringLang/DynamicPPL.jl", rev = "py/ldf"}
30 changes: 20 additions & 10 deletions test/ad.jl
@@ -94,7 +94,7 @@ encountered.

"""
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
DynamicPPL.AbstractContext
DynamicPPL.AbstractParentContext
child::ChildContext

function ADTypeCheckContext(adbackend, child)
@@ -108,7 +108,6 @@ end

adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType

DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child
function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child)
return ADTypeCheckContext(adtype(c), child)
@@ -138,14 +137,25 @@ Check that the element types in `vi` are compatible with the ADType of `context`
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
"""
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
# If we are using InitFromPrior or InitFromUniform to generate new values,
# then the parameter type will be Any, so we should skip the check.
lc = DynamicPPL.leafcontext(context)
if lc isa DynamicPPL.InitContext{
<:Any,<:Union{DynamicPPL.InitFromPrior,DynamicPPL.InitFromUniform}
}
return nothing
end
# Note that `get_param_eltype` will return `Any` with e.g. InitFromPrior or
# InitFromUniform, so this will fail. But on the bright side, you would never _really_
# use AD with those strategies, so that's fine. The cases where you do want to
# use this are DefaultContext (i.e., old, slow, LogDensityFunction) and
# InitFromParams{<:VectorWithRanges} (i.e., new, fast, LogDensityFunction), and
# both of those give you sensible results for `get_param_eltype`.
param_eltype = DynamicPPL.get_param_eltype(vi, context)
valids = valid_eltypes(context)
for val in vi[:]
valtype = typeof(val)
if !any(valtype .<: valids)
throw(IncompatibleADTypeError(valtype, adtype(context)))
end
if !(any(param_eltype .<: valids))
throw(IncompatibleADTypeError(param_eltype, adtype(context)))
end
return nothing
end

# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
@@ -200,10 +210,10 @@ end
@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
if actual_adtype == expected_adtype
# Check that this does not throw an error.
sample(contextualised_tm, sampler, 2)
sample(contextualised_tm, sampler, 2; check_model=false)
else
@test_throws AbstractWrongADBackendError sample(
contextualised_tm, sampler, 2
contextualised_tm, sampler, 2; check_model=false
)
end
end
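
The rewritten `check_adtype` relies on the element type an AD backend threads through the model (e.g. ForwardDiff's `Dual` numbers) and on the `any(T .<: valids)` idiom for matching it against the allowed types. A small illustration of that idiom; the tuple of valid types is an example, not the test file's actual `valid_eltypes`:

```julia
using ForwardDiff

# Example set of element types we would accept when ForwardDiff is the expected backend.
valids = (Float64, ForwardDiff.Dual)

dual = ForwardDiff.Dual(1.0, 1.0)     # the kind of number a ForwardDiff pass produces
any(typeof(dual) .<: valids)          # true  -> compatible
any(Float32 .<: valids)               # false -> would trigger IncompatibleADTypeError
```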
8 changes: 4 additions & 4 deletions test/mcmc/gibbs.jl
@@ -159,15 +159,15 @@ end
# It is modified by the capture_targets_and_algs function.
targets_and_algs = Any[]

function capture_targets_and_algs(sampler, context)
if DynamicPPL.NodeTrait(context) == DynamicPPL.IsLeaf()
return nothing
end
function capture_targets_and_algs(sampler, context::DynamicPPL.AbstractParentContext)
if context isa Inference.GibbsContext
push!(targets_and_algs, (context.target_varnames, sampler))
end
return capture_targets_and_algs(sampler, DynamicPPL.childcontext(context))
end
function capture_targets_and_algs(sampler, ::DynamicPPL.AbstractContext)
return nothing # Leaf context.
end

# The methods that capture testing information for us.
function AbstractMCMC.step(