diff --git a/docs/src/api.md b/docs/src/api.md
index 0dcdd26e3..f949453a3 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -54,11 +54,10 @@ logjoint
 
 ### LogDensityProblems.jl interface
 
-The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction` or `DynamicPPL.LogDensityFunctionWithGrad`.
+The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`.
 
 ```@docs
 DynamicPPL.LogDensityFunction
-DynamicPPL.LogDensityFunctionWithGrad
 ```
 
 ## Condition and decondition
diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl
index 867e8ed80..dd048aad5 100644
--- a/src/logdensityfunction.jl
+++ b/src/logdensityfunction.jl
@@ -1,18 +1,39 @@
 import DifferentiationInterface as DI
 
 """
-    LogDensityFunction
-
-A callable representing a log density function of a `model`.
-`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface,
-but only to 0th-order, i.e. it is only possible to calculate the log density,
-and not its gradient. If you need to calculate the gradient as well, you have
-to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object.
+    LogDensityFunction(
+        model::Model,
+        varinfo::AbstractVarInfo=VarInfo(model),
+        context::AbstractContext=leafcontext(model.context);
+        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
+    )
+
+A struct which contains a model, along with all the information necessary to:
+
+  - calculate its log density at a given point;
+  - and if `adtype` is provided, calculate the gradient of the log density at
+    that point.
+
+At its most basic level, a LogDensityFunction wraps the model together with
+the varinfo to be used, as well as the evaluation context. These must be
+known in order to calculate the log density (using
+[`DynamicPPL.evaluate!!`](@ref)).
+
+If the `adtype` keyword argument is provided, then this struct will also store
+the adtype along with other information for efficient calculation of the
+gradient of the log density. Note that preparing a `LogDensityFunction` with an
+AD type `AutoBackend()` requires the AD backend itself to have been loaded
+(e.g. with `import Backend`).
+
+`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
+If `adtype` is `nothing`, then only `logdensity` is implemented. If `adtype` is
+a concrete AD backend, then `logdensity_and_gradient` is also implemented.
 
 # Fields
 $(FIELDS)
 
 # Examples
+
 ```jldoctest
 julia> using Distributions
 
@@ -48,66 +69,149 @@ julia> # This also respects the context in `model`.
 
 julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
 true
+
+julia> # If we also need to calculate the gradient, we can specify an AD backend.
+       import ForwardDiff, ADTypes
+
+julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
+
+julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
+(-2.3378770664093453, [1.0])
 ```
 """
-struct LogDensityFunction{V,M,C}
-    "varinfo used for evaluation"
-    varinfo::V
+struct LogDensityFunction{
+    M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
+}
     "model used for evaluation"
     model::M
+    "varinfo used for evaluation"
+    varinfo::V
-    "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
+    "context used for evaluation; defaults to `leafcontext(model.context)`"
     context::C
-end
-
-function LogDensityFunction(
-    model::Model,
-    varinfo::AbstractVarInfo=VarInfo(model),
-    context::Union{Nothing,AbstractContext}=nothing,
-)
-    return LogDensityFunction(varinfo, model, context)
-end
+    "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
+    adtype::AD
+    "(internal use only) gradient preparation object for the model"
+    prep::Union{Nothing,DI.GradientPrep}
+    "(internal use only) whether a closure was used for the gradient preparation"
+    with_closure::Bool
 
-# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
-function getcontext(f::LogDensityFunction)
-    return f.context === nothing ? leafcontext(f.model.context) : f.context
+    function LogDensityFunction(
+        model::Model,
+        varinfo::AbstractVarInfo=VarInfo(model),
+        context::AbstractContext=leafcontext(model.context);
+        adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
+    )
+        if adtype === nothing
+            prep = nothing
+            with_closure = false
+        else
+            # Get a set of dummy params to use for prep
+            x = map(identity, varinfo[:])
+            with_closure = use_closure(adtype)
+            if with_closure
+                prep = DI.prepare_gradient(
+                    x -> logdensity_at(x, model, varinfo, context), adtype, x
+                )
+            else
+                prep = DI.prepare_gradient(
+                    logdensity_at,
+                    adtype,
+                    x,
+                    DI.Constant(model),
+                    DI.Constant(varinfo),
+                    DI.Constant(context),
+                )
+            end
+        end
+        return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
+            model, varinfo, context, adtype, prep, with_closure
+        )
+    end
 end
 
 """
-    getmodel(f)
+    setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
 
-Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
-"""
-getmodel(f::DynamicPPL.LogDensityFunction) = f.model
+Set the AD type used for evaluation of the log density gradient in the given
+LogDensityFunction. This function also performs preparation of the gradient,
+and sets the `prep` and `with_closure` fields of the LogDensityFunction.
 
-"""
-    setmodel(f, model[, adtype])
+If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.
 
-Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
+This function returns a new LogDensityFunction with the updated AD type, i.e. it
+does not mutate the input LogDensityFunction.
 """
-function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
-    return Accessors.@set f.model = model
+function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
+    return if adtype === f.adtype
+        f # Avoid recomputing prep if not needed
+    else
+        LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
+    end
 end
 
 """
-    getparams(f::LogDensityFunction)
-
-Return the parameters of the wrapped varinfo as a vector.
+    logdensity_at(
+        x::AbstractVector,
+        model::Model,
+        varinfo::AbstractVarInfo,
+        context::AbstractContext
+    )
+
+Evaluate the log density of the given `model` at the given parameter values `x`,
+using the given `varinfo` and `context`. Note that `varinfo` is used only for
+its structure: the parameter values in `x` are inserted into a copy of it, and
+any values it already contains are ignored.
 """
-getparams(f::LogDensityFunction) = f.varinfo[:]
+function logdensity_at(
+    x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
+)
+    varinfo_new = unflatten(varinfo, x)
+    return getlogp(last(evaluate!!(model, varinfo_new, context)))
+end
+
+### LogDensityProblems interface
 
-# LogDensityProblems interface: logp (0th order)
+function LogDensityProblems.capabilities(
+    ::Type{<:LogDensityFunction{M,V,C,Nothing}}
+) where {M,V,C}
+    return LogDensityProblems.LogDensityOrder{0}()
+end
+function LogDensityProblems.capabilities(
+    ::Type{<:LogDensityFunction{M,V,C,AD}}
+) where {M,V,C,AD<:ADTypes.AbstractADType}
+    return LogDensityProblems.LogDensityOrder{1}()
+end
 function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
-    context = getcontext(f)
-    vi_new = unflatten(f.varinfo, x)
-    return getlogp(last(evaluate!!(f.model, vi_new, context)))
+    return logdensity_at(x, f.model, f.varinfo, f.context)
 end
-function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
-    return LogDensityProblems.LogDensityOrder{0}()
+function LogDensityProblems.logdensity_and_gradient(
+    f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
+) where {M,V,C,AD<:ADTypes.AbstractADType}
+    f.prep === nothing &&
+        error("Gradient preparation not available; this should not happen")
+    x = map(identity, x) # Concretise type
+    return if f.with_closure
+        DI.value_and_gradient(
+            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
+        )
+    else
+        DI.value_and_gradient(
+            logdensity_at,
+            f.prep,
+            f.adtype,
+            x,
+            DI.Constant(f.model),
+            DI.Constant(f.varinfo),
+            DI.Constant(f.context),
+        )
+    end
 end
+
 # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
 LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))
 
-# LogDensityProblems interface: gradient (1st order)
+### Utils
+
 """
     use_closure(adtype::ADTypes.AbstractADType)
 
@@ -139,75 +243,24 @@ use_closure(::ADTypes.AutoMooncake) = false
 use_closure(::ADTypes.AutoReverseDiff) = true
 
 """
-    _flipped_logdensity(f::LogDensityFunction, x::AbstractVector)
+    getmodel(f)
 
-This function is the same as `LogDensityProblems.logdensity(f, x)` but with the
-arguments flipped. It is used in the 'constant' approach to DifferentiationInterface
-(see `use_closure` for more information).
+Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
 """
-function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
-    return LogDensityProblems.logdensity(f, x)
-end
+getmodel(f::DynamicPPL.LogDensityFunction) = f.model
 
 """
-    LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)
-
-A callable representing a log density function of a `model`.
-`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl
-interface to 1st-order, meaning that you can both calculate the log density
-using
-
-    LogDensityProblems.logdensity(f, x)
-
-and its gradient using
-
-    LogDensityProblems.logdensity_and_gradient(f, x)
+    setmodel(f, model)
 
-where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters.
+Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
+"""
+function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
+    return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
+end
 
-# Fields
-$(FIELDS)
 """
-struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
-    ldf::LogDensityFunction{V,M,C}
-    adtype::TAD
-    prep::DI.GradientPrep
-    with_closure::Bool
+    getparams(f::LogDensityFunction)
 
-    function LogDensityFunctionWithGrad(
-        ldf::LogDensityFunction{V,M,C}, adtype::TAD
-    ) where {V,M,C,TAD}
-        # Get a set of dummy params to use for prep
-        x = map(identity, getparams(ldf))
-        with_closure = use_closure(adtype)
-        if with_closure
-            prep = DI.prepare_gradient(
-                Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x
-            )
-        else
-            prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
-        end
-        # Store the prep with the struct. We also store whether a closure was used because
-        # we need to know this when calling `DI.value_and_gradient`. In practice we could
-        # recalculate it, but this runs the risk of introducing inconsistencies.
-        return new{V,M,C,TAD}(ldf, adtype, prep, with_closure)
-    end
-end
-function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad)
-    return LogDensityProblems.logdensity(f.ldf)
-end
-function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
-    return LogDensityProblems.LogDensityOrder{1}()
-end
-function LogDensityProblems.logdensity_and_gradient(
-    f::LogDensityFunctionWithGrad, x::AbstractVector
-)
-    x = map(identity, x) # Concretise type
-    return if f.with_closure
-        DI.value_and_gradient(
-            Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x
-        )
-    else
-        DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf))
-    end
-end
+Return the parameters of the wrapped varinfo as a vector.
+"""
+getparams(f::LogDensityFunction) = f.varinfo[:]
diff --git a/test/ad.jl b/test/ad.jl
index b8bb3ff3c..14d9b0dc4 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -1,4 +1,4 @@
-using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
+using DynamicPPL: LogDensityFunction
 
 @testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
     @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@@ -10,11 +10,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
             f = LogDensityFunction(m, varinfo)
             x = DynamicPPL.getparams(f)
             # Calculate reference logp + gradient of logp using ForwardDiff
-            default_adtype = ADTypes.AutoForwardDiff()
-            ldf_with_grad = LogDensityFunctionWithGrad(f, default_adtype)
-            ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
-                ldf_with_grad, x
-            )
+            ref_adtype = ADTypes.AutoForwardDiff()
+            ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
+            ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
 
             @testset "$adtype" for adtype in [
                 AutoReverseDiff(; compile=false),
@@ -33,20 +31,18 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
 
                 # Mooncake doesn't work with several combinations of SimpleVarInfo.
                 if is_mooncake && is_1_11 && is_svi_vnv
                     # https://github.com/compintell/Mooncake.jl/issues/470
-                    @test_throws ArgumentError LogDensityFunctionWithGrad(f, adtype)
+                    @test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype)
                 elseif is_mooncake && is_1_10 && is_svi_vnv
                     # TODO: report upstream
-                    @test_throws UndefRefError LogDensityFunctionWithGrad(f, adtype)
+                    @test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype)
                 elseif is_mooncake && is_1_10 && is_svi_od
                     # TODO: report upstream
-                    @test_throws Mooncake.MooncakeRuleCompilationError LogDensityFunctionWithGrad(
-                        f, adtype
+                    @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype(
+                        ref_ldf, adtype
                     )
                 else
-                    ldf_with_grad = LogDensityFunctionWithGrad(f, adtype)
-                    logp, grad = LogDensityProblems.logdensity_and_gradient(
-                        ldf_with_grad, x
-                    )
+                    ldf = DynamicPPL.setadtype(ref_ldf, adtype)
+                    logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
                     @test grad ≈ ref_grad
                     @test logp ≈ ref_logp
                 end
@@ -90,8 +86,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
     # Compiling the ReverseDiff tape used to fail here
     spl = Sampler(MyEmptyAlg())
    vi = VarInfo(model)
-    ldf = LogDensityFunction(vi, model, SamplingContext(spl))
-    ldf_grad = LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true))
-    @test LogDensityProblems.logdensity_and_gradient(ldf_grad, vi[:]) isa Any
+    ldf = LogDensityFunction(
+        model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
+    )
+    @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
 end
 end
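
A minimal end-to-end sketch of the unified interface introduced above. This assumes the patch is applied; the `demo` model, the observation `1.0`, and the evaluation point `[0.0]` are illustrative and not part of the diff:

```julia
# Sketch: zeroth- and first-order use of the unified LogDensityFunction.
# Assumes DynamicPPL with this patch, plus Distributions, ADTypes,
# ForwardDiff, and LogDensityProblems. The `demo` model is illustrative.
using DynamicPPL, Distributions
import ADTypes, ForwardDiff, LogDensityProblems

@model function demo(y)
    m ~ Normal()
    y ~ Normal(m, 1)
end

model = demo(1.0)

# Without an adtype, only the log density itself is available.
f = LogDensityFunction(model)
LogDensityProblems.capabilities(typeof(f))  # LogDensityOrder{0}()
LogDensityProblems.logdensity(f, [0.0])

# setadtype returns a new LogDensityFunction whose gradient preparation
# has been done once, up front; first-order queries then work too.
g = DynamicPPL.setadtype(f, ADTypes.AutoForwardDiff())
LogDensityProblems.capabilities(typeof(g))  # LogDensityOrder{1}()
logp, grad = LogDensityProblems.logdensity_and_gradient(g, [0.0])
```

Because `setadtype` compares the requested backend against `f.adtype` and returns `f` unchanged on a match, repeated calls with the same backend do not redo the potentially expensive preparation.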
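The closure-versus-constants split that `use_closure` governs can also be seen in isolation with a toy objective. A sketch against DifferentiationInterface alone; `sumsq` and the constant `3.0` are hypothetical stand-ins for `logdensity_at` and its model/varinfo/context arguments:

```julia
# Toy illustration of the two DifferentiationInterface calling conventions
# selected by `use_closure`. `sumsq` and 3.0 are hypothetical stand-ins for
# `logdensity_at` and its model/varinfo/context arguments.
import DifferentiationInterface as DI
import ADTypes, ForwardDiff

sumsq(x, a) = a * sum(abs2, x)
x = [1.0, 2.0]
backend = ADTypes.AutoForwardDiff()

# Convention 1 (use_closure = true): capture the extra argument in a closure.
g = x -> sumsq(x, 3.0)
prep1 = DI.prepare_gradient(g, backend, x)
val1, grad1 = DI.value_and_gradient(g, prep1, backend, x)

# Convention 2 (use_closure = false): pass the extra argument as a
# DI.Constant context, keeping the differentiated function type stable.
prep2 = DI.prepare_gradient(sumsq, backend, x, DI.Constant(3.0))
val2, grad2 = DI.value_and_gradient(sumsq, prep2, backend, x, DI.Constant(3.0))
```

Per the `use_closure` methods in the patch, Mooncake takes the constants path while ReverseDiff takes the closure path; storing `with_closure` in the struct keeps the call in `logdensity_and_gradient` consistent with how `prep` was built.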