Merged
Changes from 7 commits
Project.toml (2 changes: 1 addition & 1 deletion)
@@ -41,7 +41,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
-DynamicPPLMCMCChainsExt = ["MCMCChains"]
+DynamicPPLMCMCChainsExt = ["MCMCChains", "Statistics"]
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
DynamicPPLMooncakeExt = ["Mooncake"]

…
ext/DynamicPPLMCMCChainsExt.jl (38 changes: 38 additions & 0 deletions)
@@ -2,6 +2,7 @@ module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL, AbstractPPL, AbstractMCMC
using MCMCChains: MCMCChains
using Statistics: mean

_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names

@@ -140,6 +141,43 @@ function AbstractMCMC.to_samples(
end
end

function AbstractMCMC.bundle_samples(
    ts::Vector{<:DynamicPPL.ParamsWithStats},
    model::DynamicPPL.Model,
    spl::AbstractMCMC.AbstractSampler,
    state,
    chain_type::Type{MCMCChains.Chains};
    save_state=false,
    stats=missing,
    sort_chain=false,
    discard_initial=0,
    thinning=1,
    kwargs...,
)
    bare_chain = AbstractMCMC.from_samples(MCMCChains.Chains, reshape(ts, :, 1))

    # Add additional MCMC-specific info
    info = bare_chain.info
    if save_state
        info = merge(info, (model=model, sampler=spl, samplerstate=state))
    end
    if !ismissing(stats)
        info = merge(info, (start_time=stats.start, stop_time=stats.stop))
    end

    # Reconstruct the chain with the extra information
    # Yeah, this is quite ugly. Blame MCMCChains.
    chain = MCMCChains.Chains(
        bare_chain.value.data,
        names(bare_chain),
        bare_chain.name_map;
        info=info,
        start=discard_initial + 1,
        thin=thinning,
    )
    return sort_chain ? sort(chain) : chain
end
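
For reference, a minimal sketch of how this `bundle_samples` method gets reached — via `sample` with `chain_type=MCMCChains.Chains`. The sampler name here is a hypothetical stand-in; any `AbstractMCMC` sampler whose transitions are `ParamsWithStats` would hit this path:

    using DynamicPPL, MCMCChains, Distributions
    using AbstractMCMC

    @model function demo()
        x ~ Normal()
        y ~ Normal(x, 1)
    end

    # `MySampler()` is hypothetical. With `save_state=true`, the method above
    # stores the model, sampler, and final sampler state in `chain.info`.
    # chain = sample(demo(), MySampler(), 1000; chain_type=MCMCChains.Chains, save_state=true)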

"""
predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Expand Down
src/DynamicPPL.jl (2 changes: 1 addition & 1 deletion)
@@ -202,14 +202,14 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")
include("values_as_in_model.jl")
include("experimental.jl")
include("chains.jl")
include("bijector.jl")

include("debug_utils.jl")
using .DebugUtils
include("test_utils.jl")

include("experimental.jl")
include("deprecated.jl")

if isdefined(Base.Experimental, :register_error_hint)
…
src/chains.jl (57 changes: 57 additions & 0 deletions)
@@ -133,3 +133,60 @@ function ParamsWithStats(
    end
    return ParamsWithStats(params, stats)
end

"""
ParamsWithStats(
param_vector::AbstractVector,
ldf::DynamicPPL.Experimental.FastLDF,
stats::NamedTuple=NamedTuple();
include_colon_eq::Bool=true,
include_log_probs::Bool=true,
)

Generate a `ParamsWithStats` by re-evaluating the given `ldf` with the provided
`param_vector`.

This method is intended to replace the old method of obtaining parameters and statistics
via `unflatten` plus re-evaluation. It is faster for two reasons:

1. It does not rely on `deepcopy`-ing the VarInfo object (this used to be mandatory as
otherwise re-evaluation would mutate the VarInfo, rendering it unusable for subsequent
MCMC iterations).
2. The re-evaluation is faster as it uses `OnlyAccsVarInfo`.
"""
function ParamsWithStats(
param_vector::AbstractVector,
ldf::DynamicPPL.Experimental.FastLDF,
stats::NamedTuple=NamedTuple();
include_colon_eq::Bool=true,
include_log_probs::Bool=true,
)
strategy = InitFromParams(
VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector),
nothing,
)
accs = if include_log_probs
(
DynamicPPL.LogPriorAccumulator(),
DynamicPPL.LogLikelihoodAccumulator(),
DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),
[Review comment on lines +170 to +172 — Member]

I am looking at

    """
        fast_ldf_accs(getlogdensity::Function)

    Determine which accumulators are needed for fast evaluation with the given
    `getlogdensity` function.
    """
    fast_ldf_accs(::Function) = default_accumulators()
    fast_ldf_accs(::typeof(getlogjoint_internal)) = default_accumulators()
    function fast_ldf_accs(::typeof(getlogjoint))
        return AccumulatorTuple((LogPriorAccumulator(), LogLikelihoodAccumulator()))
    end
    function fast_ldf_accs(::typeof(getlogprior_internal))
        return AccumulatorTuple((LogPriorAccumulator(), LogJacobianAccumulator()))
    end
    fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
    fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))

and wonder if there are times when we need the LogJacobianAccumulator.

[Reply — Member Author]

Ah, right. I think it doesn't matter for the present purposes, because we always want the output to not include the Jacobian term, i.e. logprior and logjoint are 'as seen in the model'.

        )
    else
        (DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
    end
    _, vi = DynamicPPL.Experimental.fast_evaluate!!(
        ldf.model, strategy, AccumulatorTuple(accs)
    )
    params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
    if include_log_probs
        stats = merge(
            stats,
            (
                logprior=DynamicPPL.getlogprior(vi),
                loglikelihood=DynamicPPL.getloglikelihood(vi),
                lp=DynamicPPL.getlogjoint(vi),
            ),
        )
    end
    return ParamsWithStats(params, stats)
end
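
As a usage sketch (mirroring the test added below), assuming a parameter vector in the order the `FastLDF` expects:

    using DynamicPPL, Distributions

    @model function demo()
        x ~ Normal()
        y ~ Normal(x, 1)
    end

    model = demo()
    vi = VarInfo(model)
    fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi)

    # Re-evaluate the model at the parameters currently stored in `vi`;
    # `ps.params` maps each VarName to its value, and `ps.stats` carries
    # logprior, loglikelihood, and lp (the log joint).
    ps = ParamsWithStats(vi[:], fldf)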
src/fasteval.jl (81 changes: 58 additions & 23 deletions)
@@ -3,6 +3,7 @@ using DynamicPPL:
    AccumulatorTuple,
    InitContext,
    InitFromParams,
    AbstractInitStrategy,
    LogJacobianAccumulator,
    LogLikelihoodAccumulator,
    LogPriorAccumulator,
@@ -28,6 +29,60 @@ using LogDensityProblems: LogDensityProblems
import DifferentiationInterface as DI
using Random: Random

"""
DynamicPPL.Experimental.fast_evaluate!!(
[rng::Random.AbstractRNG,]
model::Model,
strategy::AbstractInitStrategy,
accs::AccumulatorTuple, params::AbstractVector{<:Real}
)

Evaluate a model using parameters obtained via `strategy`, and only computing the results in
the provided accumulators.

It is assumed that the accumulators passed in have been initialised to appropriate values,
as this function will not reset them. The default constructors for each accumulator will do
this for you correctly.

Returns a tuple of the model's return value, plus an `OnlyAccsVarInfo`. Note that the `accs`
argument may be mutated (depending on how the accumulators are implemented); hence the `!!`
in the function name.
"""
@inline function fast_evaluate!!(
    # Note that this `@inline` is mandatory for performance. If it's not inlined, it leads
    # to extra allocations (even for trivial models) and much slower runtime.
    rng::Random.AbstractRNG,
    model::Model,
    strategy::AbstractInitStrategy,
    accs::AccumulatorTuple,
)
    ctx = InitContext(rng, strategy)
    model = DynamicPPL.setleafcontext(model, ctx)
    # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
    # which is unnecessary. So we short-circuit this by simply calling `_evaluate!!`
    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
    # here.
    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
    # it _should_ do, but this is wrong regardless.
    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
    vi = if Threads.nthreads() > 1
        param_eltype = DynamicPPL.get_param_eltype(strategy)
        accs = map(accs) do acc
            DynamicPPL.convert_eltype(float_type_with_fallback(param_eltype), acc)
        end
        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
    else
        OnlyAccsVarInfo(accs)
    end
    return DynamicPPL._evaluate!!(model, vi)
end
@inline function fast_evaluate!!(
    model::Model, strategy::AbstractInitStrategy, accs::AccumulatorTuple
)
    # This `@inline` is also mandatory for performance
    return fast_evaluate!!(Random.default_rng(), model, strategy, accs)
end
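
A minimal sketch of calling `fast_evaluate!!` directly, assuming `InitFromPrior` and `default_accumulators` from DynamicPPL:

    using DynamicPPL, Distributions, Random

    @model function demo()
        x ~ Normal()
        y ~ Normal(x, 1)
    end

    # Draw parameters from the prior and accumulate the default quantities; the
    # returned `vi` is an OnlyAccsVarInfo holding only the accumulated results.
    retval, vi = DynamicPPL.Experimental.fast_evaluate!!(
        Random.default_rng(), demo(), InitFromPrior(), DynamicPPL.default_accumulators()
    )
    logp = DynamicPPL.getlogjoint(vi)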

"""
    FastLDF(
        model::Model,
@@ -213,31 +268,11 @@ struct FastLogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
varname_ranges::Dict{VarName,RangeAndLinked}
end
function (f::FastLogDensityAt)(params::AbstractVector{<:Real})
-    ctx = InitContext(
-        Random.default_rng(),
-        InitFromParams(
-            VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
-        ),
+    strategy = InitFromParams(
+        VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
+    )
-    model = DynamicPPL.setleafcontext(f.model, ctx)
    accs = fast_ldf_accs(f.getlogdensity)
-    # Calling `evaluate!!` would be fine, but would lead to an extra call to resetaccs!!,
-    # which is unnecessary. So we shortcircuit this by simply calling `_evaluate!!`
-    # directly. To preserve thread-safety we need to reproduce the ThreadSafeVarInfo logic
-    # here.
-    # TODO(penelopeysm): This should _not_ check Threads.nthreads(). I still don't know what
-    # it _should_ do, but this is wrong regardless.
-    # https://github.com/TuringLang/DynamicPPL.jl/issues/1086
-    vi = if Threads.nthreads() > 1
-        accs = map(
-            acc -> DynamicPPL.convert_eltype(float_type_with_fallback(eltype(params)), acc),
-            accs,
-        )
-        ThreadSafeVarInfo(OnlyAccsVarInfo(accs))
-    else
-        OnlyAccsVarInfo(accs)
-    end
-    _, vi = DynamicPPL._evaluate!!(model, vi)
+    _, vi = fast_evaluate!!(f.model, strategy, accs)
return f.getlogdensity(vi)
end
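
With this refactor, `FastLogDensityAt` is a thin wrapper over `fast_evaluate!!`. A sketch of the intended use through the LogDensityProblems interface, assuming `FastLDF` implements it (as the imports above suggest), with `FastLogDensityAt` doing the evaluation underneath:

    using DynamicPPL, Distributions
    using LogDensityProblems

    @model function demo()
        x ~ Normal()
        y ~ Normal(x, 1)
    end

    model = demo()
    vi = VarInfo(model)
    fldf = DynamicPPL.Experimental.FastLDF(model, getlogjoint, vi)

    # Evaluate the log joint at the given parameter vector.
    LogDensityProblems.logdensity(fldf, vi[:])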

…
test/chains.jl (28 changes: 27 additions & 1 deletion)
@@ -4,7 +4,7 @@ using DynamicPPL
using Distributions
using Test

-@testset "ParamsWithStats" begin
+@testset "ParamsWithStats from VarInfo" begin
    @model function f(z)
        x ~ Normal()
        y := x + 1
@@ -66,4 +66,30 @@
    end
end

@testset "ParamsWithStats from FastLDF" begin
    @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
        unlinked_vi = VarInfo(m)
        @testset "$islinked" for islinked in (false, true)
            vi = if islinked
                DynamicPPL.link!!(unlinked_vi, m)
            else
                unlinked_vi
            end
            params = [x for x in vi[:]]

            # Get the ParamsWithStats using FastLDF
            fldf = DynamicPPL.Experimental.FastLDF(m, getlogjoint, vi)
            ps = ParamsWithStats(params, fldf)

            # Check that length of parameters is as expected
            @test length(ps.params) == length(keys(vi))

            # Iterate over all variables to check that their values match
            for vn in keys(vi)
                @test ps.params[vn] == vi[vn]
            end
        end
    end
end

end # module