diff --git a/HISTORY.md b/HISTORY.md index d367e9ad7..b59d8dd7f 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -32,20 +32,40 @@ Their semantics are the same as in Julia's `isapprox`; two values are equal if t You now need to explicitly pass a `VarInfo` argument to `check_model` and `check_model_and_trace`. Previously, these functions would generate a new VarInfo for you (using an optionally provided `rng`). -### Removal of `PriorContext` and `LikelihoodContext` - -A number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`. -Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below. +### Evaluating model log-probabilities in more detail Previously, during evaluation of a model, DynamicPPL only had the capability to store a _single_ log probability (`logp`) field. `DefaultContext`, `PriorContext`, and `LikelihoodContext` were used to control what this field represented: they would accumulate the log joint, log prior, or log likelihood, respectively. -Now, we have reworked DynamicPPL's `VarInfo` object such that it can track multiple log probabilities at once (see the 'Accumulators' section below). +In this version, we have overhauled this quite substantially. +The technical details of exactly _how_ this is done is covered in the 'Accumulators' section below, but the upshot is that the log prior, log likelihood, and log Jacobian terms (for any linked variables) are separately tracked. + +Specifically, you will want to use the following functions to access these log probabilities: + + - `getlogprior(varinfo)` to get the log prior. **Note:** This version introduces new, more consistent behaviour for this function, in that it always returns the log-prior of the values in the original, untransformed space, even if the `varinfo` has been linked. + - `getloglikelihood(varinfo)` to get the log likelihood. + - `getlogjoint(varinfo)` to get the log joint probability. **Note:** Similar to `getlogprior`, this function now always returns the log joint of the values in the original, untransformed space, even if the `varinfo` has been linked. + +If you are using linked VarInfos (e.g. if you are writing a sampler), you may find that you need to obtain the log probability of the variables in the transformed space. +To this end, you can use: + + - `getlogjac(varinfo)` to get the log Jacobian of the link transforms for any linked variables. + - `getlogprior_internal(varinfo)` to get the log prior of the variables in the transformed space. + - `getlogjoint_internal(varinfo)` to get the log joint probability of the variables in the transformed space. + +Since transformations only apply to random variables, the likelihood is unaffected by linking. + +### Removal of `PriorContext` and `LikelihoodContext` + +Following on from the above, a number of DynamicPPL's contexts have been removed, most notably `PriorContext` and `LikelihoodContext`. +Although these are not the only _exported_ contexts, we consider unlikely that anyone was using _other_ contexts manually: if you have a question about contexts _other_ than these, please continue reading the 'Internals' section below. + If you were evaluating a model with `PriorContext`, you can now just evaluate it with `DefaultContext`, and instead of calling `getlogp(varinfo)`, you can call `getlogprior(varinfo)` (and similarly for the likelihood). If you were constructing a `LogDensityFunction` with `PriorContext`, you can now stick to `DefaultContext`. `LogDensityFunction` now has an extra field, called `getlogdensity`, which represents a function that takes a `VarInfo` and returns the log density you want. -Thus, if you pass `getlogprior` as the value of this parameter, you will get the same behaviour as with `PriorContext`. +Thus, if you pass `getlogprior_internal` as the value of this parameter, you will get the same behaviour as with `PriorContext`. +(You should consider whether your use case needs the log prior in the transformed space, or the original space, and use (respectively) `getlogprior_internal` or `getlogprior` as needed.) The other case where one might use `PriorContext` was to use `@addlogprob!` to add to the log prior. Previously, this was accomplished by manually checking `__context__ isa DynamicPPL.PriorContext`. diff --git a/benchmarks/src/DynamicPPLBenchmarks.jl b/benchmarks/src/DynamicPPLBenchmarks.jl index 54a302a6f..8c5032ace 100644 --- a/benchmarks/src/DynamicPPLBenchmarks.jl +++ b/benchmarks/src/DynamicPPLBenchmarks.jl @@ -86,7 +86,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked:: vi = DynamicPPL.link(vi, model) end - f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend) + f = DynamicPPL.LogDensityFunction( + model, DynamicPPL.getlogjoint_internal, vi; adtype=adbackend + ) # The parameters at which we evaluate f. θ = vi[:] diff --git a/docs/make.jl b/docs/make.jl index c69b72fb8..9c59cb06b 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -21,7 +21,9 @@ makedocs(; sitename="DynamicPPL", # The API index.html page is fairly large, and violates the default HTML page size # threshold of 200KiB, so we double that. - format=Documenter.HTML(; size_threshold=2^10 * 400), + format=Documenter.HTML(; + size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3() + ), modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)], pages=[ "Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"] diff --git a/docs/src/api.md b/docs/src/api.md index 180e8dfd4..9237943c7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -367,6 +367,7 @@ DynamicPPL provides the following default accumulators. ```@docs LogPriorAccumulator +LogJacobianAccumulator LogLikelihoodAccumulator VariableOrderAccumulator ``` @@ -380,7 +381,12 @@ getlogp setlogp!! acclogp!! getlogjoint +getlogjoint_internal +getlogjac +setlogjac!! +acclogjac!! getlogprior +getlogprior_internal setlogprior!! acclogprior!! getloglikelihood diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c282939a2..15d39014e 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -50,6 +50,7 @@ export AbstractVarInfo, AbstractAccumulator, LogLikelihoodAccumulator, LogPriorAccumulator, + LogJacobianAccumulator, VariableOrderAccumulator, push!!, empty!!, @@ -58,10 +59,15 @@ export AbstractVarInfo, getlogjoint, getlogprior, getloglikelihood, + getlogjac, + getlogjoint_internal, + getlogprior_internal, setlogp!!, setlogprior!!, + setlogjac!!, setloglikelihood!!, acclogp!!, + acclogjac!!, acclogprior!!, accloglikelihood!!, resetlogp!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 581ca829b..cf5ce5706 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -99,16 +99,34 @@ See also: [`getlogprior`](@ref), [`getloglikelihood`](@ref). """ getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) +""" + getlogjoint_internal(vi::AbstractVarInfo) + +Return the log of the joint probability of the observed data and parameters as +they are stored internally in `vi`, including the log-Jacobian for any linked +parameters. + +In general, we have that: + +```julia +getlogjoint_internal(vi) == getlogjoint(vi) - getlogjac(vi) +``` +""" +getlogjoint_internal(vi::AbstractVarInfo) = + getlogprior(vi) + getloglikelihood(vi) - getlogjac(vi) + """ getlogp(vi::AbstractVarInfo) -Return a NamedTuple of the log prior and log likelihood probabilities. +Return a NamedTuple of the log prior, log Jacobian, and log likelihood probabilities. -The keys are called `logprior` and `loglikelihood`. If either one is not present in `vi` an -error will be thrown. +The keys are called `logprior`, `logjac`, and `loglikelihood`. If any of them +are not present in `vi` an error will be thrown. """ function getlogp(vi::AbstractVarInfo) - return (; logprior=getlogprior(vi), loglikelihood=getloglikelihood(vi)) + return (; + logprior=getlogprior(vi), logjac=getlogjac(vi), loglikelihood=getloglikelihood(vi) + ) end """ @@ -164,6 +182,30 @@ See also: [`getlogjoint`](@ref), [`getloglikelihood`](@ref), [`setlogprior!!`](@ """ getlogprior(vi::AbstractVarInfo) = getacc(vi, Val(:LogPrior)).logp +""" + getlogprior_internal(vi::AbstractVarInfo) + +Return the log of the prior probability of the parameters as stored internally +in `vi`. This includes the log-Jacobian for any linked parameters. + +In general, we have that: + +```julia +getlogprior_internal(vi) == getlogprior(vi) - getlogjac(vi) +``` +""" +getlogprior_internal(vi::AbstractVarInfo) = getlogprior(vi) - getlogjac(vi) + +""" + getlogjac(vi::AbstractVarInfo) + +Return the accumulated log-Jacobian term for any linked parameters in `vi`. The +Jacobian here is taken with respect to the forward (link) transform. + +See also: [`setlogjac!!`](@ref). +""" +getlogjac(vi::AbstractVarInfo) = getacc(vi, Val(:LogJacobian)).logjac + """ getloglikelihood(vi::AbstractVarInfo) @@ -196,6 +238,16 @@ See also: [`setloglikelihood!!`](@ref), [`setlogp!!`](@ref), [`getlogprior`](@re """ setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPriorAccumulator(logp)) +""" + setlogjac!!(vi::AbstractVarInfo, logjac) + +Set the accumulated log-Jacobian term for any linked parameters in `vi`. The +Jacobian here is taken with respect to the forward (link) transform. + +See also: [`getlogjac`](@ref), [`acclogjac!!`](@ref). +""" +setlogjac!!(vi::AbstractVarInfo, logjac) = setacc!!(vi, LogJacobianAccumulator(logjac)) + """ setloglikelihood!!(vi::AbstractVarInfo, logp) @@ -215,10 +267,13 @@ Set both the log prior and the log likelihood probabilities in `vi`. See also: [`setlogprior!!`](@ref), [`setloglikelihood!!`](@ref), [`getlogp`](@ref). """ function setlogp!!(vi::AbstractVarInfo, logp::NamedTuple{names}) where {names} - if !(names == (:logprior, :loglikelihood) || names == (:loglikelihood, :logprior)) - error("logp must have the fields logprior and loglikelihood and no other fields.") + if Set(names) != Set([:logprior, :logjac, :loglikelihood]) + error( + "The second argument to `setlogp!!` must be a NamedTuple with the fields logprior, logjac, and loglikelihood.", + ) end vi = setlogprior!!(vi, logp.logprior) + vi = setlogjac!!(vi, logp.logjac) vi = setloglikelihood!!(vi, logp.loglikelihood) return vi end @@ -226,7 +281,7 @@ end function setlogp!!(vi::AbstractVarInfo, logp::Number) return error(""" `setlogp!!(vi::AbstractVarInfo, logp::Number)` is no longer supported. Use - `setloglikelihood!!` and/or `setlogprior!!` instead. + `setloglikelihood!!`, `setlogjac!!`, and/or `setlogprior!!` instead. """) end @@ -306,6 +361,19 @@ function acclogprior!!(vi::AbstractVarInfo, logp) return map_accumulator!!(acc -> acc + LogPriorAccumulator(logp), vi, Val(:LogPrior)) end +""" + acclogjac!!(vi::AbstractVarInfo, logjac) + +Add `logjac` to the value of the log Jacobian in `vi`. + +See also: [`getlogjac`](@ref), [`setlogjac!!`](@ref). +""" +function acclogjac!!(vi::AbstractVarInfo, logjac) + return map_accumulator!!( + acc -> acc + LogJacobianAccumulator(logjac), vi, Val(:LogJacobian) + ) +end + """ accloglikelihood!!(vi::AbstractVarInfo, logp) @@ -368,6 +436,9 @@ function resetlogp!!(vi::AbstractVarInfo) if hasacc(vi, Val(:LogPrior)) vi = map_accumulator!!(zero, vi, Val(:LogPrior)) end + if hasacc(vi, Val(:LogJacobian)) + vi = map_accumulator!!(zero, vi, Val(:LogJacobian)) + end if hasacc(vi, Val(:LogLikelihood)) vi = map_accumulator!!(zero, vi, Val(:LogLikelihood)) end @@ -836,9 +907,12 @@ function link!!( x = vi[:] y, logjac = with_logabsdet_jacobian(b, x) - lp_new = getlogprior(vi) - logjac - vi_new = setlogprior!!(unflatten(vi, y), lp_new) - return settrans!!(vi_new, t) + # Set parameters and add the logjac term. + vi = unflatten(vi, y) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) + end + return settrans!!(vi, t) end function invlink!!( @@ -846,11 +920,16 @@ function invlink!!( ) b = t.bijector y = vi[:] - x, logjac = with_logabsdet_jacobian(b, y) - - lp_new = getlogprior(vi) + logjac - vi_new = setlogprior!!(unflatten(vi, x), lp_new) - return settrans!!(vi_new, NoTransformation()) + x, inv_logjac = with_logabsdet_jacobian(b, y) + + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + vi = unflatten(vi, x) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, inv_logjac) + end + return settrans!!(vi, NoTransformation()) end """ diff --git a/src/accumulators.jl b/src/accumulators.jl index 1e3e37e61..0dcf9c7cf 100644 --- a/src/accumulators.jl +++ b/src/accumulators.jl @@ -11,10 +11,21 @@ seen so far. An accumulator type `T <: AbstractAccumulator` must implement the following methods: - `accumulator_name(acc::T)` or `accumulator_name(::Type{T})` -- `accumulate_observe!!(acc::T, right, left, vn)` -- `accumulate_assume!!(acc::T, val, logjac, vn, right)` +- `accumulate_observe!!(acc::T, dist, val, vn)` +- `accumulate_assume!!(acc::T, val, logjac, vn, dist)` - `Base.copy(acc::T)` +In these functions: +- `val` is the new value of the random variable sampled from a distribution (always in + the original unlinked space), or the value on the left-hand side of an observe + statement. +- `dist` is the distribution on the RHS of the tilde statement. +- `vn` is the `VarName` that is on the left-hand side of the tilde-statement. If the + tilde-statement is a literal observation like `0.0 ~ Normal()`, then `vn` is `nothing`. +- `logjac` is the log determinant of the Jacobian of the link transformation, _if_ the + variable is stored as a linked value in the VarInfo. If the variable is stored in its + original, unlinked form, then `logjac` is zero. + To be able to work with multi-threading, it should also implement: - `split(acc::T)` - `combine(acc::T, acc2::T)` diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9e9a2d63d..786d7c913 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -123,8 +123,8 @@ end function assume(dist::Distribution, vn::VarName, vi) y = getindex_internal(vi, vn) f = from_maybe_linked_internal_transform(vi, vn, dist) - x, logjac = with_logabsdet_jacobian(f, y) - vi = accumulate_assume!!(vi, x, logjac, vn, dist) + x, inv_logjac = with_logabsdet_jacobian(f, y) + vi = accumulate_assume!!(vi, x, -inv_logjac, vn, dist) return x, vi end @@ -166,6 +166,6 @@ function assume( # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - vi = accumulate_assume!!(vi, r, -logjac, vn, dist) + vi = accumulate_assume!!(vi, r, logjac, vn, dist) return r, vi end diff --git a/src/default_accumulators.jl b/src/default_accumulators.jl index 418362e8f..d503b3e64 100644 --- a/src/default_accumulators.jl +++ b/src/default_accumulators.jl @@ -3,6 +3,10 @@ An accumulator that tracks the cumulative log prior during model execution. +Note that the log prior stored in here is always calculated based on unlinked +parameters, i.e., the value of `logp` is independent of whether tha VarInfo is +linked or not. + # Fields $(TYPEDFIELDS) """ @@ -19,6 +23,49 @@ Create a new `LogPriorAccumulator` accumulator with the log prior initialized to LogPriorAccumulator{T}() where {T<:Real} = LogPriorAccumulator(zero(T)) LogPriorAccumulator() = LogPriorAccumulator{LogProbType}() +""" + LogJacobianAccumulator{T<:Real} <: AbstractAccumulator + +An accumulator that tracks the cumulative log Jacobian (technically, +log(abs(det(J)))) during model execution. Specifically, J refers to the +Jacobian of the _link transform_, i.e., from the space of the original +distribution to unconstrained space. + +!!! note + This accumulator is only incremented if the variable is transformed by a + link function, i.e., if the VarInfo is linked (for the particular + variable that is currently being accumulated). If the variable is not + linked, the log Jacobian term will be 0. + + In general, for the forward Jacobian ``\\mathbf{J}`` corresponding to the + function ``\\mathbf{y} = f(\\mathbf{x})``, + + ```math + \\log(q(\\mathbf{y})) = \\log(p(\\mathbf{x})) - \\log (|\\mathbf{J}|) + ``` + + and correspondingly: + + ```julia + getlogjoint_internal(vi) = getlogjoint(vi) - getlogjac(vi) + ``` + +# Fields +$(TYPEDFIELDS) +""" +struct LogJacobianAccumulator{T<:Real} <: AbstractAccumulator + "the logabsdet of the link transform Jacobian" + logjac::T +end + +""" + LogJacobianAccumulator{T}() + +Create a new `LogJacobianAccumulator` accumulator with the log Jacobian initialized to zero. +""" +LogJacobianAccumulator{T}() where {T<:Real} = LogJacobianAccumulator(zero(T)) +LogJacobianAccumulator() = LogJacobianAccumulator{LogProbType}() + """ LogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator @@ -71,6 +118,7 @@ VariableOrderAccumulator(n) = VariableOrderAccumulator{typeof(n)}(n) VariableOrderAccumulator() = VariableOrderAccumulator{Int}() Base.copy(acc::LogPriorAccumulator) = acc +Base.copy(acc::LogJacobianAccumulator) = acc Base.copy(acc::LogLikelihoodAccumulator) = acc function Base.copy(acc::VariableOrderAccumulator) return VariableOrderAccumulator(acc.num_produce, copy(acc.order)) @@ -79,6 +127,9 @@ end function Base.show(io::IO, acc::LogPriorAccumulator) return print(io, "LogPriorAccumulator($(repr(acc.logp)))") end +function Base.show(io::IO, acc::LogJacobianAccumulator) + return print(io, "LogJacobianAccumulator($(repr(acc.logjac)))") +end function Base.show(io::IO, acc::LogLikelihoodAccumulator) return print(io, "LogLikelihoodAccumulator($(repr(acc.logp)))") end @@ -92,6 +143,9 @@ end # equality of hashes. Both of the below implementations are also different from the default # implementation for structs. Base.:(==)(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) = acc1.logp == acc2.logp +function Base.:(==)(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return acc1.logjac == acc2.logjac +end function Base.:(==)(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return acc1.logp == acc2.logp end @@ -102,6 +156,9 @@ end function Base.isequal(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) return isequal(acc1.logp, acc2.logp) end +function Base.isequal(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return isequal(acc1.logjac, acc2.logjac) +end function Base.isequal(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return isequal(acc1.logp, acc2.logp) end @@ -110,6 +167,9 @@ function Base.isequal(acc1::VariableOrderAccumulator, acc2::VariableOrderAccumul end Base.hash(acc::LogPriorAccumulator, h::UInt) = hash((LogPriorAccumulator, acc.logp), h) +function Base.hash(acc::LogJacobianAccumulator, h::UInt) + return hash((LogJacobianAccumulator, acc.logjac), h) +end function Base.hash(acc::LogLikelihoodAccumulator, h::UInt) return hash((LogLikelihoodAccumulator, acc.logp), h) end @@ -118,16 +178,21 @@ function Base.hash(acc::VariableOrderAccumulator, h::UInt) end accumulator_name(::Type{<:LogPriorAccumulator}) = :LogPrior +accumulator_name(::Type{<:LogJacobianAccumulator}) = :LogJacobian accumulator_name(::Type{<:LogLikelihoodAccumulator}) = :LogLikelihood accumulator_name(::Type{<:VariableOrderAccumulator}) = :VariableOrder split(::LogPriorAccumulator{T}) where {T} = LogPriorAccumulator(zero(T)) +split(::LogJacobianAccumulator{T}) where {T} = LogJacobianAccumulator(zero(T)) split(::LogLikelihoodAccumulator{T}) where {T} = LogLikelihoodAccumulator(zero(T)) split(acc::VariableOrderAccumulator) = copy(acc) function combine(acc::LogPriorAccumulator, acc2::LogPriorAccumulator) return LogPriorAccumulator(acc.logp + acc2.logp) end +function combine(acc::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return LogJacobianAccumulator(acc.logjac + acc2.logjac) +end function combine(acc::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc.logp + acc2.logp) end @@ -142,6 +207,9 @@ end function Base.:+(acc1::LogPriorAccumulator, acc2::LogPriorAccumulator) return LogPriorAccumulator(acc1.logp + acc2.logp) end +function Base.:+(acc1::LogJacobianAccumulator, acc2::LogJacobianAccumulator) + return LogJacobianAccumulator(acc1.logjac + acc2.logjac) +end function Base.:+(acc1::LogLikelihoodAccumulator, acc2::LogLikelihoodAccumulator) return LogLikelihoodAccumulator(acc1.logp + acc2.logp) end @@ -150,13 +218,19 @@ function increment(acc::VariableOrderAccumulator) end Base.zero(acc::LogPriorAccumulator) = LogPriorAccumulator(zero(acc.logp)) +Base.zero(acc::LogJacobianAccumulator) = LogJacobianAccumulator(zero(acc.logjac)) Base.zero(acc::LogLikelihoodAccumulator) = LogLikelihoodAccumulator(zero(acc.logp)) function accumulate_assume!!(acc::LogPriorAccumulator, val, logjac, vn, right) - return acc + LogPriorAccumulator(logpdf(right, val) + logjac) + return acc + LogPriorAccumulator(logpdf(right, val)) end accumulate_observe!!(acc::LogPriorAccumulator, right, left, vn) = acc +function accumulate_assume!!(acc::LogJacobianAccumulator, val, logjac, vn, right) + return acc + LogJacobianAccumulator(logjac) +end +accumulate_observe!!(acc::LogJacobianAccumulator, right, left, vn) = acc + accumulate_assume!!(acc::LogLikelihoodAccumulator, val, logjac, vn, right) = acc function accumulate_observe!!(acc::LogLikelihoodAccumulator, right, left, vn) # Note that it's important to use the loglikelihood function here, not logpdf, because @@ -174,6 +248,11 @@ accumulate_observe!!(acc::VariableOrderAccumulator, right, left, vn) = increment function Base.convert(::Type{LogPriorAccumulator{T}}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) end +function Base.convert( + ::Type{LogJacobianAccumulator{T}}, acc::LogJacobianAccumulator +) where {T} + return LogJacobianAccumulator(convert(T, acc.logjac)) +end function Base.convert( ::Type{LogLikelihoodAccumulator{T}}, acc::LogLikelihoodAccumulator ) where {T} @@ -197,6 +276,9 @@ end function convert_eltype(::Type{T}, acc::LogPriorAccumulator) where {T} return LogPriorAccumulator(convert(T, acc.logp)) end +function convert_eltype(::Type{T}, acc::LogJacobianAccumulator) where {T} + return LogJacobianAccumulator(convert(T, acc.logjac)) +end function convert_eltype(::Type{T}, acc::LogLikelihoodAccumulator) where {T} return LogLikelihoodAccumulator(convert(T, acc.logp)) end @@ -206,6 +288,7 @@ function default_accumulators( ) where {FloatT,IntT} return AccumulatorTuple( LogPriorAccumulator{FloatT}(), + LogJacobianAccumulator{FloatT}(), LogLikelihoodAccumulator{FloatT}(), VariableOrderAccumulator{IntT}(), ) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 3c092c06b..3b790576a 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -18,7 +18,7 @@ is_supported(::ADTypes.AutoReverseDiff) = true """ LogDensityFunction( model::Model, - getlogdensity::Function=getlogjoint, + getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing ) @@ -29,10 +29,37 @@ A struct which contains a model, along with all the information necessary to: - and if `adtype` is provided, calculate the gradient of the log density at that point. -At its most basic level, a LogDensityFunction wraps the model together with a -function that specifies how to extract the log density, and the type of -VarInfo to be used. These must be known in order to calculate the log density -(using [`DynamicPPL.evaluate!!`](@ref)). +This information can be extracted using the LogDensityProblems.jl interface, +specifically, using `LogDensityProblems.logdensity` and +`LogDensityProblems.logdensity_and_gradient`. If `adtype` is nothing, then only +`logdensity` is implemented. If `adtype` is a concrete AD backend type, then +`logdensity_and_gradient` is also implemented. + +There are several options for `getlogdensity` that are 'supported' out of the +box: + +- [`getlogjoint_internal`](@ref): calculate the log joint, including the + log-Jacobian term for any variables that have been linked in the provided + VarInfo. +- [`getlogprior_internal`](@ref): calculate the log prior, including the + log-Jacobian term for any variables that have been linked in the provided + VarInfo. +- [`getlogjoint`](@ref): calculate the log joint in the model space, ignoring + any effects of linking +- [`getlogprior`](@ref): calculate the log prior in the model space, ignoring + any effects of linking +- [`getloglikelihood`](@ref): calculate the log likelihood (this is unaffected + by linking, since transforms are only applied to random variables) + +!!! note + By default, `LogDensityFunction` uses `getlogjoint_internal`, i.e., the + result of `LogDensityProblems.logdensity(f, x)` will depend on whether the + `LogDensityFunction` was created with a linked or unlinked VarInfo. This + is done primarily to ease interoperability with MCMC samplers. + +If you provide one of these functions, a `VarInfo` will be automatically created +for you. If you provide a different function, you have to manually create a +VarInfo and pass it as the third argument. If the `adtype` keyword argument is provided, then this struct will also store the adtype along with other information for efficient calculation of the @@ -40,10 +67,6 @@ gradient of the log density. Note that preparing a `LogDensityFunction` with an AD type `AutoBackend()` requires the AD backend itself to have been loaded (e.g. with `import Backend`). -`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface. -If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a -concrete AD backend type, then `logdensity_and_gradient` is also implemented. - # Fields $(FIELDS) @@ -74,7 +97,7 @@ julia> LogDensityProblems.dimension(f) 1 julia> # By default it uses `VarInfo` under the hood, but this is not necessary. - f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model)); + f = LogDensityFunction(model, getlogjoint_internal, SimpleVarInfo(model)); julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 @@ -99,7 +122,7 @@ struct LogDensityFunction{ } <: AbstractModel "model used for evaluation" model::M - "function to be called on `varinfo` to extract the log density. By default `getlogjoint`." + "function to be called on `varinfo` to extract the log density. By default `getlogjoint_internal`." getlogdensity::F "varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`." varinfo::V @@ -110,7 +133,7 @@ struct LogDensityFunction{ function LogDensityFunction( model::Model, - getlogdensity::Function=getlogjoint, + getlogdensity::Function=getlogjoint_internal, varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity); adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) @@ -180,7 +203,15 @@ function ldf_default_varinfo(::Model, getlogdensity::Function) return error(msg) end -ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model) +ldf_default_varinfo(model::Model, ::typeof(getlogjoint_internal)) = VarInfo(model) + +function ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogLikelihoodAccumulator())) +end + +function ldf_default_varinfo(model::Model, ::typeof(getlogprior_internal)) + return setaccs!!(VarInfo(model), (LogPriorAccumulator(), LogJacobianAccumulator())) +end function ldf_default_varinfo(model::Model, ::typeof(getlogprior)) return setaccs!!(VarInfo(model), (LogPriorAccumulator(),)) diff --git a/src/model.jl b/src/model.jl index 93e77eaec..dbbe0b85b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -995,6 +995,10 @@ Base.rand(model::Model) = rand(Random.default_rng(), NamedTuple, model) Return the log joint probability of variables `varinfo` for the probabilistic `model`. +Note that this probability always refers to the parameters in unlinked space, i.e., +the return value of `logjoint` does not depend on whether `VarInfo` has been linked +or not. + See [`logprior`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) @@ -1042,6 +1046,10 @@ end Return the log prior probability of variables `varinfo` for the probabilistic `model`. +Note that this probability always refers to the parameters in unlinked space, i.e., +the return value of `logprior` does not depend on whether `VarInfo` has been linked +or not. + See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 44882f91e..dea432022 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -74,6 +74,9 @@ function accumulate_assume!!( # T is the element type of the vectors that are the values of `acc.logps`. Usually # it's LogProbType. T = eltype(last(fieldtypes(eltype(acc.logps)))) + # Note that in only accumulating LogPrior, we effectively ignore logjac + # (since we want to return log densities that don't depend on the + # linking status of the VarInfo). subacc = accumulate_assume!!(LogPriorAccumulator{T}(), val, logjac, vn, right) push!(acc, vn, subacc.logp) end diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index abb93a0ab..0a2818e2a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -122,18 +122,18 @@ Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) -Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) +Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) Positive probability mass on negative numbers! - getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) + getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) -SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) +SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0), VariableOrder = VariableOrderAccumulator(0, Dict{VarName, Int64}()))) julia> # (✓) No probability mass on negative numbers! - getlogjoint(last(DynamicPPL.evaluate!!(m, vi))) + getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi))) -Inf ``` @@ -476,7 +476,7 @@ function assume( f = to_maybe_linked_internal_transform(vi, vn, dist) value_raw, logjac = with_logabsdet_jacobian(f, value) vi = BangBang.push!!(vi, vn, value_raw, dist) - vi = accumulate_assume!!(vi, value, -logjac, vn, dist) + vi = accumulate_assume!!(vi, value, logjac, vn, dist) return value, vi end @@ -494,6 +494,7 @@ end istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) +istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) islinked(vi::SimpleVarInfo) = istrans(vi) @@ -619,8 +620,8 @@ function link!!( x = vi.values y, logjac = with_logabsdet_jacobian(b, x) vi_new = Accessors.@set(vi.values = y) - if hasacc(vi_new, Val(:LogPrior)) - vi_new = acclogprior!!(vi_new, -logjac) + if hasacc(vi_new, Val(:LogJacobian)) + vi_new = acclogjac!!(vi_new, logjac) end return settrans!!(vi_new, t) end @@ -632,10 +633,13 @@ function invlink!!( ) b = t.bijector y = vi.values - x, logjac = with_logabsdet_jacobian(b, y) + x, inv_logjac = with_logabsdet_jacobian(b, y) vi_new = Accessors.@set(vi.values = x) - if hasacc(vi_new, Val(:LogPrior)) - vi_new = acclogprior!!(vi_new, logjac) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + if hasacc(vi_new, Val(:LogJacobian)) + vi_new = acclogjac!!(vi_new, inv_logjac) end return settrans!!(vi_new, NoTransformation()) end diff --git a/src/test_utils/ad.jl b/src/test_utils/ad.jl index d4f6f9a1d..1ac33a481 100644 --- a/src/test_utils/ad.jl +++ b/src/test_utils/ad.jl @@ -4,7 +4,8 @@ using ADTypes: AbstractADType, AutoForwardDiff using Chairmarks: @be import DifferentiationInterface as DI using DocStringExtensions -using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link +using DynamicPPL: + Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint_internal, link using LogDensityProblems: logdensity, logdensity_and_gradient using Random: AbstractRNG, default_rng using Statistics: median @@ -224,7 +225,7 @@ function run_ad( benchmark::Bool=false, atol::AbstractFloat=100 * eps(), rtol::AbstractFloat=sqrt(eps()), - getlogdensity::Function=getlogjoint, + getlogdensity::Function=getlogjoint_internal, rng::AbstractRNG=default_rng(), varinfo::AbstractVarInfo=link(VarInfo(rng, model), model), params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 9b82cd8b4..5f0a6d3e5 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -201,6 +201,11 @@ function resetlogp!!(vi::ThreadSafeVarInfo) zero, vi.accs_by_thread[i], Val(:LogPrior) ) end + if hasacc(vi, Val(:LogJacobian)) + vi.accs_by_thread[i] = map_accumulator( + zero, vi.accs_by_thread[i], Val(:LogJacobian) + ) + end if hasacc(vi, Val(:LogLikelihood)) vi.accs_by_thread[i] = map_accumulator( zero, vi.accs_by_thread[i], Val(:LogLikelihood) diff --git a/src/transforming.jl b/src/transforming.jl index e3da0ff29..56f861cff 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -15,8 +15,8 @@ NodeTrait(::DynamicTransformationContext) = IsLeaf() function tilde_assume( ::DynamicTransformationContext{isinverse}, right, vn, vi ) where {isinverse} - r = vi[vn, right] - lp = Bijectors.logpdf_with_trans(right, r, !isinverse) + # vi[vn, right] always provides the value in unlinked space. + x = vi[vn, right] if istrans(vi, vn) isinverse || @warn "Trying to link an already transformed variable ($vn)" @@ -24,13 +24,11 @@ function tilde_assume( isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" end - # Only transform if `!isinverse` since `vi[vn, right]` - # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? r : link_transform(right)(r) - if hasacc(vi, Val(:LogPrior)) - vi = acclogprior!!(vi, lp) - end - return r, setindex!!(vi, r_transformed, vn) + transform = isinverse ? identity : link_transform(right) + y, logjac = with_logabsdet_jacobian(transform, x) + vi = accumulate_assume!!(vi, x, logjac, vn, right) + vi = setindex!!(vi, y, vn) + return x, vi end function tilde_observe!!(::DynamicTransformationContext, right, left, vn, vi) @@ -53,21 +51,7 @@ function _transform!!( ) # To transform using DynamicTransformationContext, we evaluate the model using that as the leaf context: model = contextualize(model, setleafcontext(model.context, ctx)) - # but we do not need to use any accumulators other than LogPriorAccumulator - # (which is affected by the Jacobian of the transformation). - accs = getaccs(vi) - has_logprior = haskey(accs, Val(:LogPrior)) - if has_logprior - old_logprior = getacc(accs, Val(:LogPrior)) - vi = setaccs!!(vi, (old_logprior,)) - end vi = settrans!!(last(evaluate!!(model, vi)), t) - # Restore the accumulators. - if has_logprior - new_logprior = getacc(vi, Val(:LogPrior)) - accs = setacc!!(accs, new_logprior) - end - vi = setaccs!!(vi, accs) return vi end diff --git a/src/varinfo.jl b/src/varinfo.jl index d8233ae07..7b819c58f 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1148,8 +1148,8 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. setval!(md, yvec, vn) - if hasacc(vi, Val(:LogPrior)) - vi = acclogprior!!(vi, -logjac) + if hasacc(vi, Val(:LogJacobian)) + vi = acclogjac!!(vi, logjac) end return vi end @@ -1187,8 +1187,8 @@ function _link(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) end return new_varinfo end @@ -1203,8 +1203,8 @@ function _link(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) md, logjac = _link_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + new_varinfo = acclogjac!!(new_varinfo, logjac) end return new_varinfo end @@ -1351,10 +1351,13 @@ end function _invlink(model::Model, varinfo::VarInfo, vns) varinfo = deepcopy(varinfo) - md, logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) + md, inv_logjac = _invlink_metadata!!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + new_varinfo = acclogjac!!(new_varinfo, inv_logjac) end return new_varinfo end @@ -1367,10 +1370,13 @@ end function _invlink(model::Model, varinfo::NTVarInfo, vns::NamedTuple) varinfo = deepcopy(varinfo) - md, logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) + md, inv_logjac = _invlink_metadata!(model, varinfo, varinfo.metadata, vns) new_varinfo = VarInfo(md, varinfo.accs) - if hasacc(new_varinfo, Val(:LogPrior)) - new_varinfo = acclogprior!!(new_varinfo, -logjac) + if hasacc(new_varinfo, Val(:LogJacobian)) + # Mildly confusing: we need to _add_ the logjac of the inverse transform, + # because we are trying to remove the logjac of the forward transform + # that was previously accumulated when linking. + new_varinfo = acclogjac!!(new_varinfo, inv_logjac) end return new_varinfo end @@ -1382,7 +1388,7 @@ end vns::NamedTuple{vns_names}, ) where {metadata_names,vns_names} expr = quote - cumulative_logjac = zero(LogProbType) + cumulative_inv_logjac = zero(LogProbType) end mds = Expr(:tuple) for f in metadata_names @@ -1391,10 +1397,10 @@ end mds.args, quote begin - md, logjac = _invlink_metadata!!( + md, inv_logjac = _invlink_metadata!!( model, varinfo, metadata.$f, vns.$f ) - cumulative_logjac += logjac + cumulative_inv_logjac += inv_logjac md end end, @@ -1407,7 +1413,7 @@ end push!( expr.args, quote - (NamedTuple{$metadata_names}($mds), cumulative_logjac) + (NamedTuple{$metadata_names}($mds), cumulative_inv_logjac) end, ) return expr @@ -1415,7 +1421,7 @@ end function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_vns) vns = metadata.vns - cumulative_logjac = zero(LogProbType) + cumulative_inv_logjac = zero(LogProbType) # Construct the new transformed values, and keep track of their lengths. vals_new = map(vns) do vn @@ -1430,11 +1436,11 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ y = getindex_internal(varinfo, vn) dist = getdist(varinfo, vn) f = from_linked_internal_transform(varinfo, vn, dist) - x, logjac = with_logabsdet_jacobian(f, y) + x, inv_logjac = with_logabsdet_jacobian(f, y) # Vectorize value. xvec = tovec(x) # Accumulate the log-abs-det jacobian correction. - cumulative_logjac += logjac + cumulative_inv_logjac += inv_logjac # Mark as no longer transformed. settrans!!(varinfo, false, vn) # Return the vectorized transformed value. @@ -1459,25 +1465,25 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ metadata.dists, metadata.flags, ), - cumulative_logjac + cumulative_inv_logjac end function _invlink_metadata!!( ::Model, varinfo::VarInfo, metadata::VarNamedVector, target_vns ) vns = target_vns === nothing ? keys(metadata) : target_vns - cumulative_logjac = zero(LogProbType) + cumulative_inv_logjac = zero(LogProbType) for vn in vns transform = gettransform(metadata, vn) old_val = getindex_internal(metadata, vn) - new_val, logjac = with_logabsdet_jacobian(transform, old_val) + new_val, inv_logjac = with_logabsdet_jacobian(transform, old_val) # TODO(mhauru) We are calling a !! function but ignoring the return value. - cumulative_logjac += logjac + cumulative_inv_logjac += inv_logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) settrans!(metadata, false, vn) end - return metadata, cumulative_logjac + return metadata, cumulative_inv_logjac end # TODO(mhauru) The treatment of the case when some variables are linked and others are not diff --git a/test/accumulators.jl b/test/accumulators.jl index 5963ad8b5..506821c38 100644 --- a/test/accumulators.jl +++ b/test/accumulators.jl @@ -87,7 +87,9 @@ using DynamicPPL: vn = @varname(x) dist = Normal() @test accumulate_assume!!(LogPriorAccumulator(1.0), val, logjac, vn, dist) == - LogPriorAccumulator(1.0 + logjac + logpdf(dist, val)) + LogPriorAccumulator(1.0 + logpdf(dist, val)) + @test accumulate_assume!!(LogJacobianAccumulator(2.0), val, logjac, vn, dist) == + LogJacobianAccumulator(2.0 + logjac) @test accumulate_assume!!( LogLikelihoodAccumulator(1.0), val, logjac, vn, dist ) == LogLikelihoodAccumulator(1.0) @@ -101,6 +103,8 @@ using DynamicPPL: vn = @varname(x) @test accumulate_observe!!(LogPriorAccumulator(1.0), right, left, vn) == LogPriorAccumulator(1.0) + @test accumulate_observe!!(LogJacobianAccumulator(1.0), right, left, vn) == + LogJacobianAccumulator(1.0) @test accumulate_observe!!(LogLikelihoodAccumulator(1.0), right, left, vn) == LogLikelihoodAccumulator(1.0 + logpdf(right, left)) @test accumulate_observe!!(VariableOrderAccumulator(1), right, left, vn) == diff --git a/test/ad.jl b/test/ad.jl index 308894ada..371e79b06 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -30,7 +30,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos linked_varinfo = DynamicPPL.link(varinfo, m) - f = LogDensityFunction(m, getlogjoint, linked_varinfo) + f = LogDensityFunction(m, getlogjoint_internal, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff @@ -52,17 +52,17 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest if is_mooncake && is_1_11 && is_svi_vnv # https://github.com/compintell/Mooncake.jl/issues/470 @test_throws ArgumentError DynamicPPL.LogDensityFunction( - m, getlogjoint, linked_varinfo; adtype=adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_vnv # TODO: report upstream @test_throws UndefRefError DynamicPPL.LogDensityFunction( - m, getlogjoint, linked_varinfo; adtype=adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) elseif is_mooncake && is_1_10 && is_svi_od # TODO: report upstream @test_throws Mooncake.MooncakeRuleCompilationError DynamicPPL.LogDensityFunction( - m, getlogjoint, linked_varinfo; adtype=adtype + m, getlogjoint_internal, linked_varinfo; adtype=adtype ) else @test run_ad( @@ -113,7 +113,7 @@ using DynamicPPL.TestUtils.AD: run_ad, WithExpectedResult, NoTest spl = Sampler(MyEmptyAlg()) sampling_model = contextualize(model, SamplingContext(model.context)) ldf = LogDensityFunction( - sampling_model, getlogjoint; adtype=AutoReverseDiff(; compile=true) + sampling_model, getlogjoint_internal; adtype=AutoReverseDiff(; compile=true) ) x = ldf.varinfo[:] @test LogDensityProblems.logdensity_and_gradient(ldf, x) isa Any diff --git a/test/linking.jl b/test/linking.jl index b0c2dcb5c..cae101c72 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -84,8 +84,11 @@ end else DynamicPPL.link(vi, model) end - # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogjoint(vi) - DynamicPPL.getlogjoint(vi_linked) ≈ log(2) + # Difference between the internal logjoints should just be the log-absdet-jacobian "correction". + @test DynamicPPL.getlogjoint_internal(vi) - + DynamicPPL.getlogjoint_internal(vi_linked) ≈ log(2) + # The non-internal logjoint should be the same since it doesn't depend on linking. + @test DynamicPPL.getlogjoint(vi) ≈ DynamicPPL.getlogjoint(vi_linked) @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) # Linked one should be working with a lower-dimensional representation. @test length(vi_linked[:]) < length(vi[:]) @@ -98,7 +101,12 @@ end end @test length(vi_invlinked[:]) == length(vi[:]) @test vi_invlinked[@varname(m), dist] ≈ LowerTriangular(vi[@varname(m), dist]) + # The non-internal logjoint should still be the same, again since + # it doesn't depend on linking. @test DynamicPPL.getlogjoint(vi_invlinked) ≈ DynamicPPL.getlogjoint(vi) + # The internal logjoint should also be the same as before the round-trip linking. + @test DynamicPPL.getlogjoint_internal(vi_invlinked) ≈ + DynamicPPL.getlogjoint_internal(vi) end end @@ -130,7 +138,7 @@ end end @test length(vi_linked[:]) == d * (d - 1) ÷ 2 # Should now include the log-absdet-jacobian correction. - @test !(getlogjoint(vi_linked) ≈ lp) + @test !(getlogjoint_internal(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -138,7 +146,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d^2 - @test getlogjoint(vi_invlinked) ≈ lp + @test getlogjoint_internal(vi_invlinked) ≈ lp end end end @@ -164,7 +172,7 @@ end end @test length(vi_linked[:]) == d - 1 # Should now include the log-absdet-jacobian correction. - @test !(getlogjoint(vi_linked) ≈ lp) + @test !(getlogjoint_internal(vi_linked) ≈ lp) # Invlinked. vi_invlinked = if mutable DynamicPPL.invlink!!(deepcopy(vi_linked), model) @@ -172,7 +180,7 @@ end DynamicPPL.invlink(vi_linked, model) end @test length(vi_invlinked[:]) == d - @test getlogjoint(vi_invlinked) ≈ lp + @test getlogjoint_internal(vi_invlinked) ≈ lp end end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index c4d0d6beb..fbd868f71 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -26,8 +26,11 @@ end loglikelihood(model, vi) @testset "$(varinfo)" for varinfo in varinfos + # Note use of `getlogjoint` rather than `getlogjoint_internal` here ... logdensity = DynamicPPL.LogDensityFunction(model, getlogjoint, varinfo) θ = varinfo[:] + # ... because it has to match with `logjoint(model, vi)`, which always returns + # the unlinked value @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) @test LogDensityProblems.dimension(logdensity) == length(θ) end diff --git a/test/model.jl b/test/model.jl index daa3cc743..81f84e548 100644 --- a/test/model.jl +++ b/test/model.jl @@ -485,11 +485,18 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() DynamicPPL.untyped_simple_varinfo(model), ] @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + logjoint = getlogjoint(varinfo) # unlinked space varinfo_linked = DynamicPPL.link(varinfo, model) varinfo_linked_result = last( DynamicPPL.evaluate!!(model, deepcopy(varinfo_linked)) ) + # getlogjoint should return the same result as before it was linked @test getlogjoint(varinfo_linked) ≈ getlogjoint(varinfo_linked_result) + @test getlogjoint(varinfo_linked) ≈ logjoint + # getlogjoint_internal shouldn't + @test getlogjoint_internal(varinfo_linked) ≈ + getlogjoint_internal(varinfo_linked_result) + @test !isapprox(getlogjoint_internal(varinfo_linked), logjoint) end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index e300c651e..3cca1b5dc 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -89,38 +89,40 @@ @testset "link!! & invlink!! on $(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS values_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - @testset "$(typeof(vi))" for vi in ( - SimpleVarInfo(Dict()), - SimpleVarInfo(values_constrained), - SimpleVarInfo(DynamicPPL.VarNamedVector()), - DynamicPPL.typed_varinfo(model), + @testset "$name" for (name, vi) in ( + ("SVI{Dict}", SimpleVarInfo(Dict())), + ("SVI{NamedTuple}", SimpleVarInfo(values_constrained)), + ("SVI{VNV}", SimpleVarInfo(DynamicPPL.VarNamedVector())), + ("TypedVarInfo", DynamicPPL.typed_varinfo(model)), ) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end vi = last(DynamicPPL.evaluate!!(model, vi)) - # `link!!` - vi_linked = link!!(deepcopy(vi), model) - lp_linked = getlogjoint(vi_linked) - values_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + # Calculate ground truth + lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true( model, values_constrained... ) - # Should result in the correct logjoint. + _, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, values_constrained... + ) + + # `link!!` + vi_linked = link!!(deepcopy(vi), model) + lp_unlinked = getlogjoint(vi_linked) + lp_linked = getlogjoint_internal(vi_linked) @test lp_linked ≈ lp_linked_true - # Should be approx. the same as the "lazy" transformation. - @test logjoint(model, vi_linked) ≈ lp_linked + @test lp_unlinked ≈ lp_unlinked_true + @test logjoint(model, vi_linked) ≈ lp_unlinked # `invlink!!` vi_invlinked = invlink!!(deepcopy(vi_linked), model) - lp_invlinked = getlogjoint(vi_invlinked) - lp_invlinked_true = DynamicPPL.TestUtils.logjoint_true( - model, values_constrained... - ) - # Should result in the correct logjoint. - @test lp_invlinked ≈ lp_invlinked_true - # Should be approx. the same as the "lazy" transformation. - @test logjoint(model, vi_invlinked) ≈ lp_invlinked + lp_unlinked = getlogjoint(vi_invlinked) + also_lp_unlinked = getlogjoint_internal(vi_invlinked) + @test lp_unlinked ≈ lp_unlinked_true + @test also_lp_unlinked ≈ lp_unlinked_true + @test logjoint(model, vi_invlinked) ≈ lp_unlinked # Should result in same values. @test all( @@ -143,10 +145,10 @@ end svi_vnv = SimpleVarInfo(vnv) - @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( - svi_nt, - svi_dict, - svi_vnv, + @testset "$name" for (name, svi) in ( + ("NamedTuple", svi_nt), + ("Dict", svi_dict), + ("VarNamedVector", svi_vnv), # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. # DynamicPPL.settrans!!(deepcopy(svi_nt), true), # DynamicPPL.settrans!!(deepcopy(svi_dict), true), @@ -250,7 +252,7 @@ end # `getlogp` should be equal to the logjoint with log-absdet-jac correction. - lp = getlogjoint(svi) + lp = getlogjoint_internal(svi) # needs higher atol because of https://github.com/TuringLang/Bijectors.jl/issues/375 @test lp ≈ lp_true atol = 1.2e-5 end @@ -281,31 +283,36 @@ vi_linked = DynamicPPL.setindex!!(vi_linked, -rand(), vn) end - retval, vi_linked_result = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) + # NOTE: Evaluating a linked VarInfo, **specifically when the transformation + # is static**, will result in an invlinked VarInfo. This is because of + # `maybe_invlink_before_eval!`, which only invlinks if the transformation + # is static. (src/abstract_varinfo.jl) + retval, vi_unlinked_again = DynamicPPL.evaluate!!(model, deepcopy(vi_linked)) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≠ DynamicPPL.tovec(retval.s) # `s` is unconstrained in original @test DynamicPPL.tovec( - DynamicPPL.getindex_internal(vi_linked_result, @varname(s)) + DynamicPPL.getindex_internal(vi_unlinked_again, @varname(s)) ) == DynamicPPL.tovec(retval.s) # `s` is constrained in result # `m` should not be transformed. @test vi_linked[@varname(m)] == retval.m - @test vi_linked_result[@varname(m)] == retval.m + @test vi_unlinked_again[@varname(m)] == retval.m - # Compare to truth. - retval_unconstrained, lp_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + # Get ground truths + retval_unconstrained, lp_linked_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, retval.s, retval.m ) + lp_unlinked_true = DynamicPPL.TestUtils.logjoint_true(model, retval.s, retval.m) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(s))) ≈ DynamicPPL.tovec(retval_unconstrained.s) @test DynamicPPL.tovec(DynamicPPL.getindex_internal(vi_linked, @varname(m))) ≈ DynamicPPL.tovec(retval_unconstrained.m) - # The resulting varinfo should hold the correct logp. - lp = getlogjoint(vi_linked_result) - @test lp ≈ lp_true + # The unlinked varinfo should hold the unlinked logp. + lp_unlinked = getlogjoint(vi_unlinked_again) + @test getlogjoint(vi_unlinked_again) ≈ lp_unlinked_true end end end diff --git a/test/varinfo.jl b/test/varinfo.jl index dad54f024..16a9a857d 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -167,8 +167,9 @@ end vi = last(DynamicPPL.evaluate!!(m, deepcopy(vi))) @test getlogprior(vi) == lp_a + lp_b + @test getlogjac(vi) == 0.0 @test getloglikelihood(vi) == lp_c + lp_d - @test getlogp(vi) == (; logprior=lp_a + lp_b, loglikelihood=lp_c + lp_d) + @test getlogp(vi) == (; logprior=lp_a + lp_b, logjac=0.0, loglikelihood=lp_c + lp_d) @test getlogjoint(vi) == lp_a + lp_b + lp_c + lp_d @test get_num_produce(vi) == 2 @test begin @@ -183,17 +184,21 @@ end vi = setlogprior!!(vi, -1.0) getlogprior(vi) == -1.0 end + @test begin + vi = setlogjac!!(vi, -1.0) + getlogjac(vi) == -1.0 + end @test begin vi = setloglikelihood!!(vi, -1.0) getloglikelihood(vi) == -1.0 end @test begin - vi = setlogp!!(vi, (logprior=-3.0, loglikelihood=-3.0)) - getlogp(vi) == (; logprior=-3.0, loglikelihood=-3.0) + vi = setlogp!!(vi, (logprior=-3.0, logjac=-3.0, loglikelihood=-3.0)) + getlogp(vi) == (; logprior=-3.0, logjac=-3.0, loglikelihood=-3.0) end @test begin vi = acclogp!!(vi, (logprior=1.0, loglikelihood=1.0)) - getlogp(vi) == (; logprior=-2.0, loglikelihood=-2.0) + getlogp(vi) == (; logprior=-2.0, logjac=-3.0, loglikelihood=-2.0) end @test getlogp(setlogp!!(vi, getlogp(vi))) == getlogp(vi) @@ -206,7 +211,7 @@ end # need regex because 1.11 and 1.12 throw different errors (in 1.12 the # missing field is surrounded by backticks) @test_throws r"has no field `?LogLikelihood" getloglikelihood(vi) - @test_throws r"has no field `?LogLikelihood" getlogp(vi) + @test_throws r"has no field `?LogJacobian" getlogp(vi) @test_throws r"has no field `?LogLikelihood" getlogjoint(vi) @test_throws r"has no field `?VariableOrder" get_num_produce(vi) @test begin @@ -552,71 +557,52 @@ end end end - @testset "istrans" begin + @testset "logp evaluation on linked varinfo" begin @model demo_constrained() = x ~ truncated(Normal(); lower=0) model = demo_constrained() vn = @varname(x) dist = truncated(Normal(); lower=0) - ### `VarInfo` - # Need to run once since we can't specify that we want to _sample_ - # in the unconstrained space for `VarInfo` without having `vn` - # present in the `varinfo`. - - ## `untyped_varinfo` - vi = DynamicPPL.untyped_varinfo(model) + function test_linked_varinfo(model, vi) + # vn and dist are taken from the containing scope + vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) + f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) + x = f(DynamicPPL.getindex_internal(vi, vn)) + @test istrans(vi, vn) + @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + @test getloglikelihood(vi) == 0.0 + @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + @test getlogprior(vi) ≈ Bijectors.logpdf_with_trans(dist, x, false) + end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) vi = DynamicPPL.settrans!!(vi, true, vn) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:Dict}` vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:VarNamedVector}` vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) - # Sample in unconstrained space. - vi = last(DynamicPPL.evaluate_and_sample!!(model, vi)) - f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) - x = f(DynamicPPL.getindex_internal(vi, vn)) - @test getlogjoint(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) + test_linked_varinfo(model, vi) end @testset "values_as" begin @@ -719,8 +705,8 @@ end lp = logjoint(model, varinfo) @test lp ≈ lp_true @test getlogjoint(varinfo) ≈ lp_true - lp_linked = getlogjoint(varinfo_linked) - @test lp_linked ≈ lp_linked_true + lp_linked_internal = getlogjoint_internal(varinfo_linked) + @test lp_linked_internal ≈ lp_linked_true # TODO: Compare values once we are no longer working with `NamedTuple` for # the true values, e.g. `value_true`. @@ -732,6 +718,7 @@ end ) @test length(varinfo_invlinked[:]) == length(varinfo[:]) @test getlogjoint(varinfo_invlinked) ≈ lp_true + @test getlogjoint_internal(varinfo_invlinked) ≈ lp_true end end end