1+ using LogExpFunctions: LogExpFunctions
2+ using InverseFunctions: InverseFunctions, inverse
3+
14# singleton for indicating if no default arguments are present
25struct NoDefault end
36const NO_DEFAULT = NoDefault ()
@@ -342,9 +345,169 @@ function Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector{T}) where {
342345 return (x[], zero (T))
343346end
344347Bijectors. with_logabsdet_jacobian (:: Only , x:: AbstractVector ) = (x[], zero (LogProbType))
345- Bijectors. inverse (:: Only ) = NotOnly ()
348+ InverseFunctions. inverse (:: Only ) = NotOnly ()
349+ InverseFunctions. inverse (:: NotOnly ) = Only ()
346350Bijectors. with_logabsdet_jacobian (:: NotOnly , y:: T ) where {T<: Real } = ([y], zero (T))
347351Bijectors. with_logabsdet_jacobian (:: NotOnly , y) = ([y], zero (LogProbType))
352+ struct ExpOnly{L<: Real }
353+ lower:: L
354+ end
355+ (e:: ExpOnly )(y:: AbstractVector{<:Real} ) = exp (y[]) + e. lower
356+ function Bijectors. with_logabsdet_jacobian (e:: ExpOnly , y:: AbstractVector{<:Real} )
357+ yi = y[]
358+ x = exp (yi)
359+ return (x + e. lower, yi)
360+ end
361+ InverseFunctions. inverse (e:: ExpOnly ) = LogVect (e. lower)
362+ struct LogVect{L<: Real }
363+ lower:: L
364+ end
365+ (l:: LogVect )(x:: Real ) = [log (x - l. lower)]
366+ function Bijectors. with_logabsdet_jacobian (l:: LogVect , x:: Real )
367+ logx = log (x - l. lower)
368+ return ([logx], - logx)
369+ end
370+ InverseFunctions. inverse (l:: LogVect ) = ExpOnly (l. lower)
371+ struct TruncateOnly{L<: Real ,U<: Real }
372+ lower:: L
373+ upper:: U
374+ end
375+ function (t:: TruncateOnly )(y:: AbstractVector{<:Real} )
376+ lbounded, ubounded = isfinite (t. lower), isfinite (t. upper)
377+ return if lbounded && ubounded
378+ ((t. upper - t. lower) * LogExpFunctions. logistic (y[])) + t. lower
379+ elseif lbounded
380+ exp (y[]) + t. lower
381+ elseif ubounded
382+ t. upper - exp (y[])
383+ else
384+ y[]
385+ end
386+ end
387+ function Bijectors. with_logabsdet_jacobian (
388+ t:: TruncateOnly , y:: AbstractVector{T}
389+ ) where {T<: Real }
390+ lbounded, ubounded = isfinite (t. lower), isfinite (t. upper)
391+ return if lbounded && ubounded
392+ bma = t. upper - t. lower
393+ yi = y[]
394+ res = (bma * LogExpFunctions. logistic (yi)) + t. lower
395+ # TODO : Bijectors uses this:
396+ # absy = abs(yi)
397+ # return log(bma) - absy - (2 * log1pexp(-absy))
398+ # Check if it's more numerically stable. Don't immediately see a reason why, but I
399+ # assume there's a reason for it.
400+ logjac = log (bma) + yi - (2 * LogExpFunctions. log1pexp (yi))
401+ res, logjac
402+ elseif lbounded
403+ yi = y[]
404+ exp (yi) + t. lower, yi
405+ elseif ubounded
406+ yi = y[]
407+ t. upper - exp (yi), yi
408+ else
409+ y[], zero (T)
410+ end
411+ end
412+ InverseFunctions. inverse (t:: TruncateOnly ) = UntruncateVect (t. lower, t. upper)
413+
414+ struct UntruncateVect{L<: Real ,U<: Real }
415+ lower:: L
416+ upper:: U
417+ end
418+ function (u:: UntruncateVect )(x:: Real )
419+ lbounded, ubounded = isfinite (u. lower), isfinite (u. upper)
420+ return [
421+ if lbounded && ubounded
422+ LogExpFunctions. logit ((x - u. lower) / (u. upper - u. lower))
423+ elseif lbounded
424+ log (x - u. lower)
425+ elseif ubounded
426+ log (u. upper - x)
427+ else
428+ x
429+ end ,
430+ ]
431+ end
432+ function Bijectors. with_logabsdet_jacobian (u:: UntruncateVect , x:: Real )
433+ lbounded, ubounded = isfinite (u. lower), isfinite (u. upper)
434+ return if lbounded && ubounded
435+ bma = u. upper - u. lower
436+ xma = x - u. lower
437+ xma_over_bma = xma / bma
438+ [LogExpFunctions. logit (xma_over_bma)], - log (xma_over_bma * (u. upper - x))
439+ elseif lbounded
440+ log_xma = log (x - u. lower)
441+ [log_xma], - log_xma
442+ elseif ubounded
443+ log_bmx = log (u. upper - x)
444+ [log_bmx], - log_bmx
445+ else
446+ return zero (x)
447+ end
448+ end
449+ InverseFunctions. inverse (u:: UntruncateVect ) = TruncateOnly (u. lower, u. upper)
450+
451+ for dist_type in [
452+ Distributions. Cauchy,
453+ Distributions. Chernoff,
454+ Distributions. Gumbel,
455+ Distributions. JohnsonSU,
456+ Distributions. Laplace,
457+ Distributions. Logistic,
458+ Distributions. NoncentralT,
459+ Distributions. Normal,
460+ Distributions. NormalCanon,
461+ Distributions. NormalInverseGaussian,
462+ Distributions. PGeneralizedGaussian,
463+ Distributions. SkewedExponentialPower,
464+ Distributions. SkewNormal,
465+ Distributions. TDist,
466+ ]
467+ @eval begin
468+ from_linked_vec_transform (:: $dist_type ) = Only ()
469+ to_linked_vec_transform (:: $dist_type ) = NotOnly ()
470+ end
471+ end
472+ for dist_type in [
473+ Distributions. BetaPrime,
474+ Distributions. Chi,
475+ Distributions. Chisq,
476+ Distributions. Erlang,
477+ Distributions. Exponential,
478+ Distributions. FDist,
479+ # Wikipedia's definition of the Frechet distribution allows for a location parameter,
480+ # which could cause its minimum to be nonzero. However, Distributionsistributions.jl's `Frechet`
481+ # does not implement this, so we can lump it in here.
482+ Distributions. Frechet,
483+ Distributions. Gamma,
484+ Distributions. InverseGamma,
485+ Distributions. InverseGaussian,
486+ Distributions. Kolmogorov,
487+ Distributions. Lindley,
488+ Distributions. LogNormal,
489+ Distributions. NoncentralChisq,
490+ Distributions. NoncentralF,
491+ Distributions. Rayleigh,
492+ Distributions. Rician,
493+ Distributions. StudentizedRange,
494+ Distributions. Weibull,
495+ ]
496+ @eval begin
497+ from_linked_vec_transform (d:: $dist_type ) = ExpOnly (minimum (d))
498+ to_linked_vec_transform (d:: $dist_type ) = LogVect (minimum (d))
499+ end
500+ end
501+ function to_linked_vec_transform (d:: Distributions.ContinuousUnivariateDistribution )
502+ return UntruncateVect (minimum (d), maximum (d))
503+ end
504+ function from_linked_vec_transform (d:: Distributions.ContinuousUnivariateDistribution )
505+ return TruncateOnly (minimum (d), maximum (d))
506+ end
507+ from_vec_transform (:: Distributions.UnivariateDistribution ) = Only ()
508+ to_vec_transform (:: Distributions.UnivariateDistribution ) = NotOnly ()
509+ from_linked_vec_transform (:: DiscreteUnivariateDistribution ) = Only ()
510+ to_linked_vec_transform (:: DiscreteUnivariateDistribution ) = NotOnly ()
348511
349512"""
350513 from_vec_transform(x)
@@ -371,7 +534,6 @@ Return the transformation from the vector representation of a realization from
371534distribution `dist` to the original representation compatible with `dist`.
372535"""
373536from_vec_transform (dist:: Distribution ) = from_vec_transform_for_size (size (dist))
374- from_vec_transform (:: UnivariateDistribution ) = Only ()
375537from_vec_transform (dist:: LKJCholesky ) = ToChol (dist. uplo) ∘ ReshapeTransform (size (dist))
376538
377539struct ProductNamedTupleUnvecTransform{names,T<: NamedTuple{names} }
@@ -453,10 +615,6 @@ function from_linked_vec_transform(dist::Distribution)
453615 f_vec = from_vec_transform (inverse (f_invlink), size (dist))
454616 return f_invlink ∘ f_vec
455617end
456- function from_linked_vec_transform (dist:: UnivariateDistribution )
457- # This is a performance optimisation
458- return Only () ∘ invlink_transform (dist)
459- end
460618function from_linked_vec_transform (dist:: Distributions.ProductNamedTupleDistribution )
461619 return invlink_transform (dist)
462620end
0 commit comments