3 changes: 1 addition & 2 deletions docs/src/api.md
@@ -54,11 +54,10 @@ logjoint

### LogDensityProblems.jl interface

The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction` or `DynamicPPL.LogDensityFunctionWithGrad`.
The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) interface is also supported by wrapping a [`Model`](@ref) in a `DynamicPPL.LogDensityFunction`.

```@docs
DynamicPPL.LogDensityFunction
DynamicPPL.LogDensityFunctionWithGrad
```
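
As a minimal sketch of the zeroth-order usage (the model here is illustrative; the pattern matches the `LogDensityFunction` docstring examples):

```julia
using DynamicPPL, Distributions, LogDensityProblems

@model demo() = x ~ Normal()

f = DynamicPPL.LogDensityFunction(demo())

# Log density of the model at x = 0.0; equals logpdf(Normal(), 0.0).
LogDensityProblems.logdensity(f, [0.0])
```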

## Condition and decondition
266 changes: 160 additions & 106 deletions src/logdensityfunction.jl
@@ -1,18 +1,39 @@
import DifferentiationInterface as DI

"""
LogDensityFunction

A callable representing a log density function of a `model`.
`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface,
but only to 0th-order, i.e. it is only possible to calculate the log density,
and not its gradient. If you need to calculate the gradient as well, you have
to construct a [`DynamicPPL.LogDensityFunctionWithGrad`](@ref) object.
LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::AbstractContext=DefaultContext();
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
)

A struct which contains a model, along with all the information necessary to:

- calculate its log density at a given point;
- and if `adtype` is provided, calculate the gradient of the log density at
that point.

At its most basic level, a LogDensityFunction wraps the model together with
the varinfo to be used, as well as the evaluation context. These must be known
in order to calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)).

If the `adtype` keyword argument is provided, then this struct will also store
the adtype along with other information for efficient calculation of the
gradient of the log density. Note that preparing a `LogDensityFunction` with an
AD type `AutoBackend()` requires the AD backend itself to have been loaded
(e.g. with `import Backend`).

`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
concrete AD backend type, then `logdensity_and_gradient` is also implemented.

# Fields
$(FIELDS)

# Examples

```jldoctest
julia> using Distributions

@@ -48,66 +69,150 @@ julia> # This also respects the context in `model`.

julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
true

julia> # If we also need to calculate the gradient, we can specify an AD backend.
       import ForwardDiff, ADTypes

julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
(-2.3378770664093453, [1.0])
```
"""
struct LogDensityFunction{V,M,C}
"varinfo used for evaluation"
varinfo::V
struct LogDensityFunction{
M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
}
"model used for evaluation"
model::M
"varinfo used for evaluation"
varinfo::V
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
context::C
end

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::Union{Nothing,AbstractContext}=nothing,
)
return LogDensityFunction(varinfo, model, context)
end
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
adtype::AD
"(internal use only) gradient preparation object for the model"
prep::Union{Nothing,DI.GradientPrep}
"(internal use only) whether a closure was used for the gradient preparation"
with_closure::Bool

# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
function getcontext(f::LogDensityFunction)
return f.context === nothing ? leafcontext(f.model.context) : f.context
function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::AbstractContext=leafcontext(model.context);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)
if adtype === nothing
prep = nothing
with_closure = false
else
# Get a set of dummy params to use for prep
x = map(identity, varinfo[:])
with_closure = use_closure(adtype)
if with_closure
prep = DI.prepare_gradient(
x -> logdensity_at(x, model, varinfo, context), adtype, x
)
else
prep = DI.prepare_gradient(
logdensity_at,
adtype,
x,
DI.Constant(model),
DI.Constant(varinfo),
DI.Constant(context),
)
end
end
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
model, varinfo, context, adtype, prep, with_closure
)
end
end

"""
getmodel(f)
setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})

Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
getmodel(f::DynamicPPL.LogDensityFunction) = f.model
Set the AD type used for evaluation of log density gradient in the given LogDensityFunction.
This function also performs preparation of the gradient, and sets the `prep`
and `with_closure` fields of the LogDensityFunction.

"""
setmodel(f, model[, adtype])
If `adtype` is `nothing`, the `prep` field will be set to `nothing` as well.

Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
This function returns a new LogDensityFunction with the updated AD type, i.e. it does
not mutate the input LogDensityFunction.
"""
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return Accessors.@set f.model = model
function setadtype(f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType})
return if adtype === f.adtype
f # Avoid recomputing prep if not needed
else
LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
end
end
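
As a usage sketch (the model and choice of backends are illustrative; the behaviour is as documented above):

```julia
using DynamicPPL, Distributions, ADTypes, LogDensityProblems
import ForwardDiff, ReverseDiff

@model demo() = x ~ Normal()

# No adtype: only zeroth-order queries are possible.
f = DynamicPPL.LogDensityFunction(demo())

# Attach a backend; returns a new object with gradient prep populated.
f_fd = DynamicPPL.setadtype(f, ADTypes.AutoForwardDiff())

# Setting the same adtype again returns the input unchanged, skipping re-prep.
f_fd === DynamicPPL.setadtype(f_fd, ADTypes.AutoForwardDiff())

# A different backend triggers fresh gradient preparation.
f_rd = DynamicPPL.setadtype(f_fd, ADTypes.AutoReverseDiff())
LogDensityProblems.logdensity_and_gradient(f_rd, [0.0])
```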

"""
getparams(f::LogDensityFunction)

Return the parameters of the wrapped varinfo as a vector.
logdensity_at(
x::AbstractVector,
model::Model,
varinfo::AbstractVarInfo,
context::AbstractContext
)

Evaluate the log density of the given `model` at the given parameter values `x`,
using the given `varinfo` and `context`. Note that the `varinfo` argument is provided
only for its structure, in the sense that the parameters from the vector `x` are inserted into
it, and its own parameters are discarded.
"""
getparams(f::LogDensityFunction) = f.varinfo[:]
function logdensity_at(
x::AbstractVector, model::Model, varinfo::AbstractVarInfo, context::AbstractContext
)
varinfo_new = unflatten(varinfo, x)
return getlogp(last(evaluate!!(model, varinfo_new, context)))
end
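
A sketch of these semantics (illustrative model; the point is that the values stored in `varinfo` do not affect the result):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()
model = demo()
vi = VarInfo(model)  # holds some randomly-initialised value for x
ctx = DynamicPPL.leafcontext(model.context)

# The result depends only on `x`, not on what `vi` currently stores:
DynamicPPL.logdensity_at([1.0], model, vi, ctx) == logpdf(Normal(), 1.0)
```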

### LogDensityProblems interface

# LogDensityProblems interface: logp (0th order)
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,C,Nothing}}
) where {M,V,C}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,C,AD}}
) where {M,V,C,AD<:ADTypes.AbstractADType}
return LogDensityProblems.LogDensityOrder{1}()
end
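
For example (sketch; the backend choice is arbitrary):

```julia
using DynamicPPL, Distributions, ADTypes, LogDensityProblems
import ForwardDiff

@model demo() = x ~ Normal()

f0 = DynamicPPL.LogDensityFunction(demo())
LogDensityProblems.capabilities(typeof(f0))  # LogDensityOrder{0}()

f1 = DynamicPPL.LogDensityFunction(demo(); adtype=ADTypes.AutoForwardDiff())
LogDensityProblems.capabilities(typeof(f1))  # LogDensityOrder{1}()
```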
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
context = getcontext(f)
vi_new = unflatten(f.varinfo, x)
return getlogp(last(evaluate!!(f.model, vi_new, context)))
return logdensity_at(x, f.model, f.varinfo, f.context)
end
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
return LogDensityProblems.LogDensityOrder{0}()
function LogDensityProblems.logdensity_and_gradient(
f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
) where {M,V,C,AD<:ADTypes.AbstractADType}
f.prep === nothing &&
error("Gradient preparation not available; this should not happen")
x = map(identity, x) # Concretise type
return if f.with_closure
DI.value_and_gradient(
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
)
else
DI.value_and_gradient(
logdensity_at,
f.prep,
f.adtype,
x,
DI.Constant(f.model),
DI.Constant(f.varinfo),
DI.Constant(f.context),
)
end
end

# TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)?
LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f))

# LogDensityProblems interface: gradient (1st order)
### Utils

"""
use_closure(adtype::ADTypes.AbstractADType)

@@ -139,75 +244,24 @@ use_closure(::ADTypes.AutoMooncake) = false
use_closure(::ADTypes.AutoReverseDiff) = true
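
To make the distinction concrete, here is a standalone sketch of the two DifferentiationInterface calling styles that `use_closure` selects between (the function `g` is made up for illustration):

```julia
import DifferentiationInterface as DI
using ADTypes
import ForwardDiff

g(x, a) = a * sum(abs2, x)
x = [1.0, 2.0]
adtype = ADTypes.AutoForwardDiff()

# Closure style (use_closure(adtype) == true): capture `a` in a closure.
prep = DI.prepare_gradient(x -> g(x, 3.0), adtype, x)
DI.value_and_gradient(x -> g(x, 3.0), prep, adtype, x)

# Constant style (use_closure(adtype) == false): pass `a` as a DI.Constant.
prep = DI.prepare_gradient(g, adtype, x, DI.Constant(3.0))
DI.value_and_gradient(g, prep, adtype, x, DI.Constant(3.0))
```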

"""
_flipped_logdensity(f::LogDensityFunction, x::AbstractVector)
getmodel(f)

This function is the same as `LogDensityProblems.logdensity(f, x)` but with the
arguments flipped. It is used in the 'constant' approach to DifferentiationInterface
(see `use_closure` for more information).
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
function _flipped_logdensity(x::AbstractVector, f::LogDensityFunction)
return LogDensityProblems.logdensity(f, x)
end
getmodel(f::DynamicPPL.LogDensityFunction) = f.model

"""
LogDensityFunctionWithGrad(ldf::DynamicPPL.LogDensityFunction, adtype::ADTypes.AbstractADType)

A callable representing a log density function of a `model`.
`DynamicPPL.LogDensityFunctionWithGrad` implements the LogDensityProblems.jl
interface to 1st-order, meaning that you can both calculate the log density
using

LogDensityProblems.logdensity(f, x)

and its gradient using

LogDensityProblems.logdensity_and_gradient(f, x)
setmodel(f, model[, adtype])

where `f` is a `LogDensityFunctionWithGrad` object and `x` is a vector of parameters.
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
"""
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
end

# Fields
$(FIELDS)
"""
struct LogDensityFunctionWithGrad{V,M,C,TAD<:ADTypes.AbstractADType}
ldf::LogDensityFunction{V,M,C}
adtype::TAD
prep::DI.GradientPrep
with_closure::Bool
getparams(f::LogDensityFunction)

function LogDensityFunctionWithGrad(
ldf::LogDensityFunction{V,M,C}, adtype::TAD
) where {V,M,C,TAD}
# Get a set of dummy params to use for prep
x = map(identity, getparams(ldf))
with_closure = use_closure(adtype)
if with_closure
prep = DI.prepare_gradient(
Base.Fix1(LogDensityProblems.logdensity, ldf), adtype, x
)
else
prep = DI.prepare_gradient(_flipped_logdensity, adtype, x, DI.Constant(ldf))
end
# Store the prep with the struct. We also store whether a closure was used because
# we need to know this when calling `DI.value_and_gradient`. In practice we could
# recalculate it, but this runs the risk of introducing inconsistencies.
return new{V,M,C,TAD}(ldf, adtype, prep, with_closure)
end
end
function LogDensityProblems.logdensity(f::LogDensityFunctionWithGrad)
return LogDensityProblems.logdensity(f.ldf)
end
function LogDensityProblems.capabilities(::Type{<:LogDensityFunctionWithGrad})
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity_and_gradient(
f::LogDensityFunctionWithGrad, x::AbstractVector
)
x = map(identity, x) # Concretise type
return if f.with_closure
DI.value_and_gradient(
Base.Fix1(LogDensityProblems.logdensity, f.ldf), f.prep, f.adtype, x
)
else
DI.value_and_gradient(_flipped_logdensity, f.prep, f.adtype, x, DI.Constant(f.ldf))
end
end
Return the parameters of the wrapped varinfo as a vector.
"""
getparams(f::LogDensityFunction) = f.varinfo[:]
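
A sketch of these accessors used together (model illustrative):

```julia
using DynamicPPL, Distributions

@model demo() = x ~ Normal()

f = DynamicPPL.LogDensityFunction(demo())
DynamicPPL.getmodel(f)               # the wrapped Model
DynamicPPL.getparams(f)              # parameters of the wrapped varinfo, as a Vector
f2 = DynamicPPL.setmodel(f, demo())  # new LogDensityFunction; varinfo, context, adtype kept
```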
31 changes: 14 additions & 17 deletions test/ad.jl
@@ -1,4 +1,4 @@
using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
using DynamicPPL: LogDensityFunction

@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
@@ -10,11 +10,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
f = LogDensityFunction(m, varinfo)
x = DynamicPPL.getparams(f)
# Calculate reference logp + gradient of logp using ForwardDiff
default_adtype = ADTypes.AutoForwardDiff()
ldf_with_grad = LogDensityFunctionWithGrad(f, default_adtype)
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(
ldf_with_grad, x
)
ref_adtype = ADTypes.AutoForwardDiff()
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)

@testset "$adtype" for adtype in [
AutoReverseDiff(; compile=false),
@@ -33,20 +31,18 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
# Mooncake doesn't work with several combinations of SimpleVarInfo.
if is_mooncake && is_1_11 && is_svi_vnv
# https://github.com/compintell/Mooncake.jl/issues/470
@test_throws ArgumentError LogDensityFunctionWithGrad(f, adtype)
@test_throws ArgumentError DynamicPPL.setadtype(ref_ldf, adtype)
elseif is_mooncake && is_1_10 && is_svi_vnv
# TODO: report upstream
@test_throws UndefRefError LogDensityFunctionWithGrad(f, adtype)
@test_throws UndefRefError DynamicPPL.setadtype(ref_ldf, adtype)
elseif is_mooncake && is_1_10 && is_svi_od
# TODO: report upstream
@test_throws Mooncake.MooncakeRuleCompilationError LogDensityFunctionWithGrad(
f, adtype
@test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.setadtype(
ref_ldf, adtype
)
else
ldf_with_grad = LogDensityFunctionWithGrad(f, adtype)
logp, grad = LogDensityProblems.logdensity_and_gradient(
ldf_with_grad, x
)
ldf = DynamicPPL.setadtype(ref_ldf, adtype)
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
@test grad ≈ ref_grad
@test logp ≈ ref_logp
end
@@ -90,8 +86,9 @@ using DynamicPPL: LogDensityFunction, LogDensityFunctionWithGrad
# Compiling the ReverseDiff tape used to fail here
spl = Sampler(MyEmptyAlg())
vi = VarInfo(model)
ldf = LogDensityFunction(vi, model, SamplingContext(spl))
ldf_grad = LogDensityFunctionWithGrad(ldf, AutoReverseDiff(; compile=true))
@test LogDensityProblems.logdensity_and_gradient(ldf_grad, vi[:]) isa Any
ldf = LogDensityFunction(
model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
)
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
end
end