diff --git a/HISTORY.md b/HISTORY.md index 775850973..85f195b95 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,10 @@ # DynamicPPL Changelog +## 0.39.11 + +Allow passing `accs::Union{NTuple{N,AbstractAccumulator},AccumulatorTuple}` into the `LogDensityFunction` constructor to specify custom accumulators to use when evaluating the model. +Previously, this was hard-coded. + ## 0.39.10 Rename the internal functions `matchingvalue` and `get_matching_type` to `convert_model_argument` and `promote_model_type_argument` respectively. diff --git a/Project.toml b/Project.toml index 988de202e..d03136c6b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.39.10" +version = "0.39.11" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/accumulators.jl b/src/accumulators.jl index 0208f19a5..ed5f28ec2 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -157,6 +157,7 @@ end AccumulatorTuple(accs::Vararg{AbstractAccumulator}) = AccumulatorTuple(accs) AccumulatorTuple(nt::NamedTuple) = AccumulatorTuple(tuple(nt...)) +AccumulatorTuple(at::AccumulatorTuple) = at # When showing with text/plain, leave out information about the wrapper AccumulatorTuple. Base.show(io::IO, mime::MIME"text/plain", at::AccumulatorTuple) = show(io, mime, at.nt) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index d67960d5f..5a61eb531 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -33,7 +33,8 @@ using Random: Random DynamicPPL.LogDensityFunction( model::Model, getlogdensity::Any=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); + varinfo::AbstractVarInfo=VarInfo(model) + accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=DynamicPPL.ldf_accs(getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) @@ -72,6 +73,12 @@ If you provide one of these functions, a `VarInfo` will be automatically created you provide a different function, you have to manually create a VarInfo and pass it as the third argument. +`accs` allows you to specify an `AccumulatorTuple` or a tuple of `AbstractAccumulator`s +which will be used _when evaluating the log density_`. (Note that the accumulators from the +`VarInfo` argument are discarded.) By default, this uses an internal function, +`DynamicPPL.ldf_accs`, which attempts to choose an appropriate set of accumulators based on +which kind of log-density is being calculated. + If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the gradient of the log density. Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD @@ -152,6 +159,7 @@ struct LogDensityFunction{ ADP<:Union{Nothing,DI.GradientPrep}, # type of the vector passed to logdensity functions X<:AbstractVector, + AC<:AccumulatorTuple, } model::M adtype::AD @@ -160,11 +168,19 @@ struct LogDensityFunction{ _varname_ranges::Dict{VarName,RangeAndLinked} _adprep::ADP _dim::Int + _accs::AC function LogDensityFunction( model::Model, getlogdensity::Any=getlogjoint_internal, - varinfo::AbstractVarInfo=VarInfo(model); + # TODO(penelopeysm): It is a bit redundant to pass a VarInfo, as well as the + # accumulators, into here. The truth is that the VarInfo is used ONLY for generating + # the ranges and link status, so arguably we should only pass in a metadata; or when + # VNT is done, we should pass in only a VNT. + varinfo::AbstractVarInfo=VarInfo(model), + accs::Union{NTuple{<:Any,AbstractAccumulator},AccumulatorTuple}=ldf_accs( + getlogdensity + ); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) # Figure out which variable corresponds to which index, and @@ -187,13 +203,15 @@ struct LogDensityFunction{ end x = [val for val in varinfo[:]] dim = length(x) + # convert to AccumulatorTuple if needed + accs = AccumulatorTuple(accs) # Do AD prep if needed prep = if adtype === nothing nothing else # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) - args = (model, getlogdensity, all_iden_ranges, all_ranges) + args = (model, getlogdensity, all_iden_ranges, all_ranges, accs) if _use_closure(adtype) DI.prepare_gradient(LogDensityAt{Tlink}(args...), adtype, x) else @@ -214,8 +232,9 @@ struct LogDensityFunction{ typeof(all_iden_ranges), typeof(prep), typeof(x), + typeof(accs), }( - model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim + model, adtype, getlogdensity, all_iden_ranges, all_ranges, prep, dim, accs ) end end @@ -264,11 +283,11 @@ function logdensity_at( getlogdensity::Any, iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, + accs::AccumulatorTuple, ) where {Tlink} strategy = InitFromParams( VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params), nothing ) - accs = ldf_accs(getlogdensity) _, vi = DynamicPPL.init!!(model, OnlyAccsVarInfo(accs), strategy) return getlogdensity(vi) end @@ -279,25 +298,30 @@ end getlogdensity::Any, iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, + accs::AccumulatorTuple, ) where {Tlink} A callable struct that behaves in the same way as `logdensity_at`, but stores the model and other information internally. Having two separate functions/structs allows for better performance with AD backends. """ -struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple} +struct LogDensityAt{Tlink,M<:Model,F,N<:NamedTuple,A<:AccumulatorTuple} model::M getlogdensity::F iden_varname_ranges::N varname_ranges::Dict{VarName,RangeAndLinked} + accs::A function LogDensityAt{Tlink}( model::M, getlogdensity::F, iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, - ) where {Tlink,M,F,N} - return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges) + accs::A, + ) where {Tlink,M,F,N,A} + return new{Tlink,M,F,N,A}( + model, getlogdensity, iden_varname_ranges, varname_ranges, accs + ) end end function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} @@ -308,6 +332,7 @@ function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} f.getlogdensity, f.iden_varname_ranges, f.varname_ranges, + f.accs, ) end @@ -321,6 +346,7 @@ function LogDensityProblems.logdensity( ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges, + ldf._accs, ) end @@ -333,7 +359,11 @@ function LogDensityProblems.logdensity_and_gradient( return if _use_closure(ldf.adtype) DI.value_and_gradient( LogDensityAt{Tlink}( - ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges + ldf.model, + ldf._getlogdensity, + ldf._iden_varname_ranges, + ldf._varname_ranges, + ldf._accs, ), ldf._adprep, ldf.adtype, @@ -350,6 +380,7 @@ function LogDensityProblems.logdensity_and_gradient( DI.Constant(ldf._getlogdensity), DI.Constant(ldf._iden_varname_ranges), DI.Constant(ldf._varname_ranges), + DI.Constant(ldf._accs), ) end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 5160dab25..ceec4d02a 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -123,6 +123,38 @@ end x = vi[:] @test LogDensityProblems.logdensity(ldf, x) == sll.scale * logpdf(Normal(x[1]), 1.0) end + + @testset "Custom accumulators" begin + # Define an accumulator that always throws an error to test that custom + # accumulators can be used with LogDensityFunction + struct ErrorAccumulatorException <: Exception end + struct ErrorAccumulator <: DynamicPPL.AbstractAccumulator end + DynamicPPL.accumulator_name(::ErrorAccumulator) = :ERROR + DynamicPPL.accumulate_assume!!( + ::ErrorAccumulator, ::Any, ::Any, ::VarName, ::Distribution + ) = throw(ErrorAccumulatorException()) + DynamicPPL.accumulate_observe!!( + ::ErrorAccumulator, ::Distribution, ::Any, ::Union{VarName,Nothing} + ) = throw(ErrorAccumulatorException()) + DynamicPPL.reset(ea::ErrorAccumulator) = ea + Base.copy(ea::ErrorAccumulator) = ea + # Construct an LDF + @model function demo_error() + return x ~ Normal() + end + model = demo_error() + # check that passing accs as a tuple works + ldf = LogDensityFunction(model, getlogjoint, VarInfo(model), (ErrorAccumulator(),)) + @test_throws ErrorAccumulatorException LogDensityProblems.logdensity(ldf, [0.0]) + # check that passing accs as AccumulatorTuple also works + ldf = LogDensityFunction( + model, + getlogjoint, + VarInfo(model), + DynamicPPL.AccumulatorTuple(ErrorAccumulator()), + ) + @test_throws ErrorAccumulatorException LogDensityProblems.logdensity(ldf, [0.0]) + end end @testset "LogDensityFunction: Type stability" begin