Skip to content

Commit acff274

Browse files
committed
just port all the code i wrote
1 parent 72a123a commit acff274

File tree

2 files changed

+168
-6
lines changed

2 files changed

+168
-6
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@ DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1717
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1818
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1919
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
20+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
2021
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2122
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
23+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
2224
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2325
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2426
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -62,10 +64,12 @@ DocStringExtensions = "0.9"
6264
EnzymeCore = "0.6 - 0.8"
6365
ForwardDiff = "0.10.12, 1"
6466
InteractiveUtils = "1"
67+
InverseFunctions = "0.1.17"
6568
JET = "0.9, 0.10, 0.11"
6669
KernelAbstractions = "0.9.33"
6770
LinearAlgebra = "1.6"
6871
LogDensityProblems = "2"
72+
LogExpFunctions = "0.3.29"
6973
MCMCChains = "6, 7"
7074
MacroTools = "0.5.6"
7175
MarginalLogDensities = "0.4.3"

src/utils.jl

Lines changed: 164 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
using LogExpFunctions: LogExpFunctions
2+
using InverseFunctions: InverseFunctions, inverse
3+
14
# singleton for indicating if no default arguments are present
25
struct NoDefault end
36
const NO_DEFAULT = NoDefault()
@@ -342,9 +345,169 @@ function Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector{T}) where {
342345
return (x[], zero(T))
343346
end
344347
Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector) = (x[], zero(LogProbType))
345-
Bijectors.inverse(::Only) = NotOnly()
348+
InverseFunctions.inverse(::Only) = NotOnly()
349+
InverseFunctions.inverse(::NotOnly) = Only()
346350
Bijectors.with_logabsdet_jacobian(::NotOnly, y::T) where {T<:Real} = ([y], zero(T))
347351
Bijectors.with_logabsdet_jacobian(::NotOnly, y) = ([y], zero(LogProbType))
352+
struct ExpOnly{L<:Real}
353+
lower::L
354+
end
355+
(e::ExpOnly)(y::AbstractVector{<:Real}) = exp(y[]) + e.lower
356+
function Bijectors.with_logabsdet_jacobian(e::ExpOnly, y::AbstractVector{<:Real})
357+
yi = y[]
358+
x = exp(yi)
359+
return (x + e.lower, yi)
360+
end
361+
InverseFunctions.inverse(e::ExpOnly) = LogVect(e.lower)
362+
struct LogVect{L<:Real}
363+
lower::L
364+
end
365+
(l::LogVect)(x::Real) = [log(x - l.lower)]
366+
function Bijectors.with_logabsdet_jacobian(l::LogVect, x::Real)
367+
logx = log(x - l.lower)
368+
return ([logx], -logx)
369+
end
370+
InverseFunctions.inverse(l::LogVect) = ExpOnly(l.lower)
371+
struct TruncateOnly{L<:Real,U<:Real}
372+
lower::L
373+
upper::U
374+
end
375+
function (t::TruncateOnly)(y::AbstractVector{<:Real})
376+
lbounded, ubounded = isfinite(t.lower), isfinite(t.upper)
377+
return if lbounded && ubounded
378+
((t.upper - t.lower) * LogExpFunctions.logistic(y[])) + t.lower
379+
elseif lbounded
380+
exp(y[]) + t.lower
381+
elseif ubounded
382+
t.upper - exp(y[])
383+
else
384+
y[]
385+
end
386+
end
387+
function Bijectors.with_logabsdet_jacobian(
388+
t::TruncateOnly, y::AbstractVector{T}
389+
) where {T<:Real}
390+
lbounded, ubounded = isfinite(t.lower), isfinite(t.upper)
391+
return if lbounded && ubounded
392+
bma = t.upper - t.lower
393+
yi = y[]
394+
res = (bma * LogExpFunctions.logistic(yi)) + t.lower
395+
# TODO: Bijectors uses this:
396+
# absy = abs(yi)
397+
# return log(bma) - absy - (2 * log1pexp(-absy))
398+
# Check if it's more numerically stable. Don't immediately see a reason why, but I
399+
# assume there's a reason for it.
400+
logjac = log(bma) + yi - (2 * LogExpFunctions.log1pexp(yi))
401+
res, logjac
402+
elseif lbounded
403+
yi = y[]
404+
exp(yi) + t.lower, yi
405+
elseif ubounded
406+
yi = y[]
407+
t.upper - exp(yi), yi
408+
else
409+
y[], zero(T)
410+
end
411+
end
412+
InverseFunctions.inverse(t::TruncateOnly) = UntruncateVect(t.lower, t.upper)
413+
414+
struct UntruncateVect{L<:Real,U<:Real}
415+
lower::L
416+
upper::U
417+
end
418+
function (u::UntruncateVect)(x::Real)
419+
lbounded, ubounded = isfinite(u.lower), isfinite(u.upper)
420+
return [
421+
if lbounded && ubounded
422+
LogExpFunctions.logit((x - u.lower) / (u.upper - u.lower))
423+
elseif lbounded
424+
log(x - u.lower)
425+
elseif ubounded
426+
log(u.upper - x)
427+
else
428+
x
429+
end,
430+
]
431+
end
432+
function Bijectors.with_logabsdet_jacobian(u::UntruncateVect, x::Real)
433+
lbounded, ubounded = isfinite(u.lower), isfinite(u.upper)
434+
return if lbounded && ubounded
435+
bma = u.upper - u.lower
436+
xma = x - u.lower
437+
xma_over_bma = xma / bma
438+
[LogExpFunctions.logit(xma_over_bma)], -log(xma_over_bma * (u.upper - x))
439+
elseif lbounded
440+
log_xma = log(x - u.lower)
441+
[log_xma], -log_xma
442+
elseif ubounded
443+
log_bmx = log(u.upper - x)
444+
[log_bmx], -log_bmx
445+
else
446+
return zero(x)
447+
end
448+
end
449+
InverseFunctions.inverse(u::UntruncateVect) = TruncateOnly(u.lower, u.upper)
450+
451+
for dist_type in [
452+
Distributions.Cauchy,
453+
Distributions.Chernoff,
454+
Distributions.Gumbel,
455+
Distributions.JohnsonSU,
456+
Distributions.Laplace,
457+
Distributions.Logistic,
458+
Distributions.NoncentralT,
459+
Distributions.Normal,
460+
Distributions.NormalCanon,
461+
Distributions.NormalInverseGaussian,
462+
Distributions.PGeneralizedGaussian,
463+
Distributions.SkewedExponentialPower,
464+
Distributions.SkewNormal,
465+
Distributions.TDist,
466+
]
467+
@eval begin
468+
from_linked_vec_transform(::$dist_type) = Only()
469+
to_linked_vec_transform(::$dist_type) = NotOnly()
470+
end
471+
end
472+
for dist_type in [
473+
Distributions.BetaPrime,
474+
Distributions.Chi,
475+
Distributions.Chisq,
476+
Distributions.Erlang,
477+
Distributions.Exponential,
478+
Distributions.FDist,
479+
# Wikipedia's definition of the Frechet distribution allows for a location parameter,
480+
# which could cause its minimum to be nonzero. However, Distributionsistributions.jl's `Frechet`
481+
# does not implement this, so we can lump it in here.
482+
Distributions.Frechet,
483+
Distributions.Gamma,
484+
Distributions.InverseGamma,
485+
Distributions.InverseGaussian,
486+
Distributions.Kolmogorov,
487+
Distributions.Lindley,
488+
Distributions.LogNormal,
489+
Distributions.NoncentralChisq,
490+
Distributions.NoncentralF,
491+
Distributions.Rayleigh,
492+
Distributions.Rician,
493+
Distributions.StudentizedRange,
494+
Distributions.Weibull,
495+
]
496+
@eval begin
497+
from_linked_vec_transform(d::$dist_type) = ExpOnly(minimum(d))
498+
to_linked_vec_transform(d::$dist_type) = LogVect(minimum(d))
499+
end
500+
end
501+
function to_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution)
502+
return UntruncateVect(minimum(d), maximum(d))
503+
end
504+
function from_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution)
505+
return TruncateOnly(minimum(d), maximum(d))
506+
end
507+
from_vec_transform(::Distributions.UnivariateDistribution) = Only()
508+
to_vec_transform(::Distributions.UnivariateDistribution) = NotOnly()
509+
from_linked_vec_transform(::DiscreteUnivariateDistribution) = Only()
510+
to_linked_vec_transform(::DiscreteUnivariateDistribution) = NotOnly()
348511

349512
"""
350513
from_vec_transform(x)
@@ -371,7 +534,6 @@ Return the transformation from the vector representation of a realization from
371534
distribution `dist` to the original representation compatible with `dist`.
372535
"""
373536
from_vec_transform(dist::Distribution) = from_vec_transform_for_size(size(dist))
374-
from_vec_transform(::UnivariateDistribution) = Only()
375537
from_vec_transform(dist::LKJCholesky) = ToChol(dist.uplo) ReshapeTransform(size(dist))
376538

377539
struct ProductNamedTupleUnvecTransform{names,T<:NamedTuple{names}}
@@ -453,10 +615,6 @@ function from_linked_vec_transform(dist::Distribution)
453615
f_vec = from_vec_transform(inverse(f_invlink), size(dist))
454616
return f_invlink f_vec
455617
end
456-
function from_linked_vec_transform(dist::UnivariateDistribution)
457-
# This is a performance optimisation
458-
return Only() invlink_transform(dist)
459-
end
460618
function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
461619
return invlink_transform(dist)
462620
end

0 commit comments

Comments
 (0)