diff --git a/HISTORY.md b/HISTORY.md index 777f3f32c..eca4be995 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -49,6 +49,8 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space. This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function). +Improved performance of transformations of univariate distributions' samples to and from their vectorised forms. + ## 0.38.9 Remove warning when using Enzyme as the AD backend. diff --git a/Project.toml b/Project.toml index 1b5e52492..74c5b1cc1 100644 --- a/Project.toml +++ b/Project.toml @@ -17,8 +17,10 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -62,10 +64,12 @@ DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" ForwardDiff = "0.10.12, 1" InteractiveUtils = "1" +InverseFunctions = "0.1.17" JET = "0.9, 0.10, 0.11" KernelAbstractions = "0.9.33" LinearAlgebra = "1.6" LogDensityProblems = "2" +LogExpFunctions = "0.3.29" MCMCChains = "6, 7" MacroTools = "0.5.6" MarginalLogDensities = "0.4.3" diff --git a/src/utils.jl b/src/utils.jl index 2d7b0404f..ddbc96742 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 
+1,6 @@ +using LogExpFunctions: LogExpFunctions +using InverseFunctions: InverseFunctions, inverse + # singleton for indicating if no default arguments are present struct NoDefault end const NO_DEFAULT = NoDefault() @@ -261,42 +264,6 @@ invlink_transform(dist) = inverse(link_transform(dist)) # Helper functions for vectorize/reconstruct values # ##################################################### -""" - UnwrapSingletonTransform(input_size::InSize) - -A transformation that unwraps a singleton array, returning a scalar. - -The `input_size` field is the expected size of the input. In practice this only determines -the number of indices, since all dimensions must be 1 for a singleton. `input_size` is used -to check the validity of the input, but also to determine the correct inverse operation. - -By default `input_size` is `(1,)`, in which case `tovec` is the inverse. -""" -struct UnwrapSingletonTransform{InSize} <: Bijectors.Bijector - input_size::InSize -end - -UnwrapSingletonTransform() = UnwrapSingletonTransform((1,)) - -function (f::UnwrapSingletonTransform)(x) - if size(x) != f.input_size - throw(DimensionMismatch("Expected input of size $(f.input_size), got $(size(x))")) - end - return only(x) -end - -function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) - return f(x), zero(LogProbType) -end - -function Bijectors.with_logabsdet_jacobian( - inv_f::Bijectors.Inverse{<:UnwrapSingletonTransform}, x -) - f = inv_f.orig - result = reshape([x], f.input_size) - return result, zero(LogProbType) -end - """ ReshapeTransform(input_size::InSize, output_size::OutSize) @@ -370,6 +337,178 @@ function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y) ) end +struct Only end # callable: unwraps a single-element vector via `x[]` +struct NotOnly end # callable: wraps a value as `[y]`; `inverse` pairs it with Only +(::Only)(x) = x[] +(::NotOnly)(y) = [y] +function Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector{T}) where {T<:Real} + return (x[], zero(T)) # log-abs-det-Jacobian is zero, typed like the elements of x +end +Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector) = (x[], zero(LogProbType)) 
+InverseFunctions.inverse(::Only) = NotOnly() +InverseFunctions.inverse(::NotOnly) = Only() +Bijectors.with_logabsdet_jacobian(::NotOnly, y::T) where {T<:Real} = ([y], zero(T)) +Bijectors.with_logabsdet_jacobian(::NotOnly, y) = ([y], zero(LogProbType)) +struct ExpOnly{L<:Real} # callable: 1-element vector y -> exp(y[]) + lower + lower::L +end +(e::ExpOnly)(y::AbstractVector{<:Real}) = exp(y[]) + e.lower +function Bijectors.with_logabsdet_jacobian(e::ExpOnly, y::AbstractVector{<:Real}) + yi = y[] + x = exp(yi) + return (x + e.lower, yi) # d/dy (exp(y) + lower) = exp(y), so the log-Jacobian is y +end +InverseFunctions.inverse(e::ExpOnly) = LogVect(e.lower) +struct LogVect{L<:Real} # callable: scalar x -> [log(x - lower)]; inverse of ExpOnly + lower::L +end +(l::LogVect)(x::Real) = [log(x - l.lower)] +function Bijectors.with_logabsdet_jacobian(l::LogVect, x::Real) + logx = log(x - l.lower) + return ([logx], -logx) # d/dx log(x - lower) = 1 / (x - lower) +end +InverseFunctions.inverse(l::LogVect) = ExpOnly(l.lower) +struct TruncateOnly{L<:Real,U<:Real} # callable: 1-element vector -> scalar; see branches + lower::L + upper::U +end +function (t::TruncateOnly)(y::AbstractVector{<:Real}) + lbounded, ubounded = isfinite(t.lower), isfinite(t.upper) + return if lbounded && ubounded + ((t.upper - t.lower) * LogExpFunctions.logistic(y[])) + t.lower + elseif lbounded + exp(y[]) + t.lower + elseif ubounded + t.upper - exp(y[]) + else + y[] + end +end +function Bijectors.with_logabsdet_jacobian( + t::TruncateOnly, y::AbstractVector{T} +) where {T<:Real} + lbounded, ubounded = isfinite(t.lower), isfinite(t.upper) + return if lbounded && ubounded + bma = t.upper - t.lower + yi = y[] + res = (bma * LogExpFunctions.logistic(yi)) + t.lower + # TODO: Bijectors computes the log-Jacobian as: + # absy = abs(yi) + # return log(bma) - absy - (2 * log1pexp(-absy)) + # The two forms are algebraically equal, since log1pexp(y) == y + log1pexp(-y). + # NOTE(review): the abs form may lose less precision for large |yi|; confirm. 
+ logjac = log(bma) + yi - (2 * LogExpFunctions.log1pexp(yi)) + res, logjac + elseif lbounded + yi = y[] + exp(yi) + t.lower, yi + elseif ubounded + yi = y[] + t.upper - exp(yi), yi + else + y[], zero(T) + end +end +InverseFunctions.inverse(t::TruncateOnly) = UntruncateVect(t.lower, t.upper) + +struct UntruncateVect{L<:Real,U<:Real} # scalar x -> 1-element vector; undoes TruncateOnly + lower::L + upper::U +end +function (u::UntruncateVect)(x::Real) + lbounded, ubounded = isfinite(u.lower), isfinite(u.upper) + return [ + if lbounded && ubounded + LogExpFunctions.logit((x - u.lower) / (u.upper - u.lower)) + elseif lbounded + log(x - u.lower) + elseif ubounded + log(u.upper - x) + else + x + end, + ] +end +function Bijectors.with_logabsdet_jacobian(u::UntruncateVect, x::Real) + lbounded, ubounded = isfinite(u.lower), isfinite(u.upper) + return if lbounded && ubounded + bma = u.upper - u.lower + xma = x - u.lower + xma_over_bma = xma / bma + [LogExpFunctions.logit(xma_over_bma)], -log(xma_over_bma * (u.upper - x)) + elseif lbounded + log_xma = log(x - u.lower) + [log_xma], -log_xma + elseif ubounded + log_bmx = log(u.upper - x) + [log_bmx], -log_bmx + else + [x], zero(x) # no finite bound: identity map, zero log-Jacobian + end +end +InverseFunctions.inverse(u::UntruncateVect) = TruncateOnly(u.lower, u.upper) + +for dist_type in [ + Distributions.Cauchy, + Distributions.Chernoff, + Distributions.Gumbel, + Distributions.JohnsonSU, + Distributions.Laplace, + Distributions.Logistic, + Distributions.NoncentralT, + Distributions.Normal, + Distributions.NormalCanon, + Distributions.NormalInverseGaussian, + Distributions.PGeneralizedGaussian, + Distributions.SkewedExponentialPower, + Distributions.SkewNormal, + Distributions.TDist, +] + @eval begin # linked form is the plain value for these: only (un)wrap the vector + from_linked_vec_transform(::$dist_type) = Only() + to_linked_vec_transform(::$dist_type) = NotOnly() + end +end +for dist_type in [ + Distributions.BetaPrime, + Distributions.Chi, + Distributions.Chisq, + Distributions.Erlang, + Distributions.Exponential, + Distributions.FDist, + # Wikipedia's definition of the Frechet 
distribution allows for a location parameter, + # which could cause its minimum to be nonzero. However, Distributions.jl's `Frechet` + # does not implement this, so we can lump it in here. + Distributions.Frechet, + Distributions.Gamma, + Distributions.InverseGamma, + Distributions.InverseGaussian, + Distributions.Kolmogorov, + Distributions.Lindley, + Distributions.LogNormal, + Distributions.NoncentralChisq, + Distributions.NoncentralF, + Distributions.Rayleigh, + Distributions.Rician, + Distributions.StudentizedRange, + Distributions.Weibull, +] + @eval begin # log/exp link, shifted by the distribution's `minimum` + from_linked_vec_transform(d::$dist_type) = ExpOnly(minimum(d)) + to_linked_vec_transform(d::$dist_type) = LogVect(minimum(d)) + end +end +function to_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution) + return UntruncateVect(minimum(d), maximum(d)) # fallback for types not listed above +end +function from_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution) + return TruncateOnly(minimum(d), maximum(d)) # inverse of the fallback above +end +from_vec_transform(::Distributions.UnivariateDistribution) = Only() +to_vec_transform(::Distributions.UnivariateDistribution) = NotOnly() +from_linked_vec_transform(::DiscreteUnivariateDistribution) = Only() +to_linked_vec_transform(::DiscreteUnivariateDistribution) = NotOnly() + """ from_vec_transform(x) @@ -377,7 +516,7 @@ Return the transformation from the vector representation of `x` to original repr """ from_vec_transform(x::AbstractArray) = from_vec_transform_for_size(size(x)) from_vec_transform(C::Cholesky) = ToChol(C.uplo) ∘ ReshapeTransform(size(C.UL)) -from_vec_transform(::Real) = UnwrapSingletonTransform() +from_vec_transform(::Real) = Only() """ from_vec_transform_for_size(sz::Tuple) @@ -395,7 +534,6 @@ Return the transformation from the vector representation of a realization from distribution `dist` to the original representation compatible with `dist`. 
""" from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist)) -from_vec_transform(::UnivariateDistribution) = UnwrapSingletonTransform() from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ∘ ReshapeTransform(size(dist)) struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}} @@ -441,7 +579,7 @@ end # This function returns the length of the vector that the function from_vec_transform # expects. This helps us determine which segment of a concatenated vector belongs to which # variable. -_input_length(from_vec_trfm::UnwrapSingletonTransform) = 1 +_input_length(::Only) = 1 _input_length(from_vec_trfm::ReshapeTransform) = prod(from_vec_trfm.output_size) function _input_length(trfm::ProductNamedTupleUnvecTransform) return sum(_input_length ∘ from_vec_transform, values(trfm.dists)) @@ -477,19 +615,6 @@ function from_linked_vec_transform(dist::Distribution) f_vec = from_vec_transform(inverse(f_invlink), size(dist)) return f_invlink ∘ f_vec end - -# UnivariateDistributions need to be handled as a special case, because size(dist) is (), -# which makes the usual machinery think we are dealing with a 0-dim array, whereas in -# actuality we are dealing with a scalar. -# TODO(mhauru) Hopefully all this can go once the old Gibbs sampler is removed and -# VarNamedVector takes over from Metadata. -function from_linked_vec_transform(dist::UnivariateDistribution) - f_invlink = invlink_transform(dist) - f_vec = from_vec_transform(inverse(f_invlink), size(dist)) - f_combined = f_invlink ∘ f_vec - sz = Bijectors.output_size(f_combined, size(dist)) - return UnwrapSingletonTransform(sz) ∘ f_combined -end function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution) return invlink_transform(dist) end