Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# DynamicPPL Changelog

## 0.39.11

Allow passing `accs::Union{NTuple{N,AbstractAccumulator},AccumulatorTuple}` into the `LogDensityFunction` constructor to specify custom accumulators to use when evaluating the model.
Previously, this was hard-coded.

## 0.39.10

Rename the internal functions `matchingvalue` and `get_matching_type` to `convert_model_argument` and `promote_model_type_argument` respectively.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.39.10"
version = "0.39.11"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
1 change: 1 addition & 0 deletions src/accumulators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ end

AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs)
AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...))
AccumulatorTuple(at::AccumulatorTuple) = at

# When showing with text/plain, leave out information about the wrapper AccumulatorTuple.
Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt)
Expand Down
49 changes: 40 additions & 9 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ using Random: Random
DynamicPPL.LogDensityFunction(
model::Model,
getlogdensity::Any=getlogjoint_internal,
varinfo::AbstractVarInfo=VarInfo(model);
varinfo::AbstractVarInfo=VarInfo(model)
accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=DynamicPPL.ldf_accs(getlogdensity);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)

Expand Down Expand Up @@ -72,6 +73,12 @@ If you provide one of these functions, a `VarInfo` will be automatically created
you provide a different function, you have to manually create a VarInfo and pass it as the
third argument.

`accs` allows you to specify an `AccumulatorTuple` or a tuple of `AbstractAccumulator`s
which will be used _when evaluating the log density_`. (Note that the accumulators from the
`VarInfo` argument are discarded.) By default, this uses an internal function,
`DynamicPPL.ldf_accs`, which attempts to choose an appropriate set of accumulators based on
which kind of log-density is being calculated.

If the `adtype` keyword argument is provided, then this struct will also store the adtype
along with other information for efficient calculation of the gradient of the log density.
Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD
Expand Down Expand Up @@ -152,6 +159,7 @@ struct LogDensityFunction{
ADP<:Union{Nothing,DI.GradientPrep},
# type of the vector passed to logdensity functions
X<:AbstractVector,
AC<:AccumulatorTuple,
}
model::M
adtype::AD
Expand All @@ -160,11 +168,19 @@ struct LogDensityFunction{
_varname_ranges::Dict{VarName,RangeAndLinked}
_adprep::ADP
_dim::Int
_accs::AC

function LogDensityFunction(
model::Model,
getlogdensity::Any=getlogjoint_internal,
varinfo::AbstractVarInfo=VarInfo(model);
# TODO(penelopeysm): It is a bit redundant to pass a VarInfo, as well as the
# accumulators, into here. The truth is that the VarInfo is used ONLY for generating
# the ranges and link status, so arguably we should only pass in a metadata; or when
# VNT is done, we should pass in only a VNT.
varinfo::AbstractVarInfo=VarInfo(model),
accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=ldf_accs(
getlogdensity
);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)
# Figure out which variable corresponds to which index, and
Expand All @@ -187,13 +203,15 @@ struct LogDensityFunction{
end
x = [val for val in varinfo[:]]
dim = length(x)
# convert to AccumulatorTuple if needed
accs = AccumulatorTuple(accs)
# Do AD prep if needed
prep = if adtype === nothing
nothing
else
# Make backend-specific tweaks to the adtype
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
args = (model, getlogdensity, all_iden_ranges, all_ranges)
args = (model, getlogdensity, all_iden_ranges, all_ranges, accs)
if _use_closure(adtype)
DI.prepare_gradient(LogDensityAt{Tlink}(args...), adtype, x)
else
Expand All @@ -214,8 +232,9 @@ struct LogDensityFunction{
typeof(all_iden_ranges),
typeof(prep),
typeof(x),
typeof(accs),
}(
model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim
model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim, accs
)
end
end
Expand Down Expand Up @@ -264,11 +283,11 @@ function logdensity_at(
getlogdensity::Any,
iden_varname_ranges::NamedTuple,
varname_ranges::Dict{VarName,RangeAndLinked},
accs::AccumulatorTuple,
) where {Tlink}
strategy = InitFromParams(
VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params), nothing
)
accs = ldf_accs(getlogdensity)
_, vi = DynamicPPL.init!!(model, OnlyAccsVarInfo(accs), strategy)
return getlogdensity(vi)
end
Expand All @@ -279,25 +298,30 @@ end
getlogdensity::Any,
iden_varname_ranges::NamedTuple,
varname_ranges::Dict{VarName,RangeAndLinked},
accs::AccumulatorTuple,
) where {Tlink}

A callable struct that behaves in the same way as `logdensity_at`, but stores the model and
other information internally. Having two separate functions/structs allows for better
performance with AD backends.
"""
struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple}
struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple,A<:AccumulatorTuple}
model::M
getlogdensity::F
iden_varname_ranges::N
varname_ranges::Dict{VarName,RangeAndLinked}
accs::A

function LogDensityAt{Tlink}(
model::M,
getlogdensity::F,
iden_varname_ranges::N,
varname_ranges::Dict{VarName,RangeAndLinked},
) where {Tlink,M,F,N}
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
accs::A,
) where {Tlink,M,F,N,A}
return new{Tlink,M,F,N,A}(
model, getlogdensity, iden_varname_ranges, varname_ranges, accs
)
end
end
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
Expand All @@ -308,6 +332,7 @@ function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
f.getlogdensity,
f.iden_varname_ranges,
f.varname_ranges,
f.accs,
)
end

Expand All @@ -321,6 +346,7 @@ function LogDensityProblems.logdensity(
ldf._getlogdensity,
ldf._iden_varname_ranges,
ldf._varname_ranges,
ldf._accs,
)
end

Expand All @@ -333,7 +359,11 @@ function LogDensityProblems.logdensity_and_gradient(
return if _use_closure(ldf.adtype)
DI.value_and_gradient(
LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
ldf.model,
ldf._getlogdensity,
ldf._iden_varname_ranges,
ldf._varname_ranges,
ldf._accs,
),
ldf._adprep,
ldf.adtype,
Expand All @@ -350,6 +380,7 @@ function LogDensityProblems.logdensity_and_gradient(
DI.Constant(ldf._getlogdensity),
DI.Constant(ldf._iden_varname_ranges),
DI.Constant(ldf._varname_ranges),
DI.Constant(ldf._accs),
)
end
end
Expand Down
32 changes: 32 additions & 0 deletions test/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,38 @@ end
x = vi[:]
@test LogDensityProblems.logdensity(ldf, x) == sll.scale * logpdf(Normal(x[1]), 1.0)
end

@testset "Custom accumulators" begin
# Define an accumulator that always throws an error to test that custom
# accumulators can be used with LogDensityFunction
struct ErrorAccumulatorException <: Exception end
struct ErrorAccumulator <: DynamicPPL.AbstractAccumulator end
DynamicPPL.accumulator_name(::ErrorAccumulator) = :ERROR
DynamicPPL.accumulate_assume!!(
::ErrorAccumulator, ::Any, ::Any, ::VarName, ::Distribution
) = throw(ErrorAccumulatorException())
DynamicPPL.accumulate_observe!!(
::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing}
) = throw(ErrorAccumulatorException())
DynamicPPL.reset(ea::ErrorAccumulator) = ea
Base.copy(ea::ErrorAccumulator) = ea
# Construct an LDF
@model function demo_error()
return x ~ Normal()
end
model = demo_error()
# check that passing accs as a tuple works
ldf = LogDensityFunction(model, getlogjoint, VarInfo(model), (ErrorAccumulator(),))
@test_throws ErrorAccumulatorException LogDensityProblems.logdensity(ldf, [0.0])
# check that passing accs as AccumulatorTuple also works
ldf = LogDensityFunction(
model,
getlogjoint,
VarInfo(model),
DynamicPPL.AccumulatorTuple(ErrorAccumulator()),
)
@test_throws ErrorAccumulatorException LogDensityProblems.logdensity(ldf, [0.0])
end
end

@testset "LogDensityFunction: Type stability" begin
Expand Down