Skip to content

Commit 7bf80bf

Browse files
committed
just port all the code i wrote
1 parent 72a123a commit 7bf80bf

File tree

2 files changed

+157
-4
lines changed

2 files changed

+157
-4
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1919
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
2020
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2121
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
22+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
2223
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2324
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2425
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -66,6 +67,7 @@ JET = "0.9, 0.10, 0.11"
6667
KernelAbstractions = "0.9.33"
6768
LinearAlgebra = "1.6"
6869
LogDensityProblems = "2"
70+
LogExpFunctions = "0.3.29"
6971
MCMCChains = "6, 7"
7072
MacroTools = "0.5.6"
7173
MarginalLogDensities = "0.4.3"

src/utils.jl

Lines changed: 155 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using LogExpFunctions: LogExpFunctions
2+
13
# singleton for indicating if no default arguments are present
24
struct NoDefault end
35
const NO_DEFAULT = NoDefault()
@@ -345,6 +347,159 @@ Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector) = (x[], zero(LogPro
345347
Bijectors.inverse(::Only) = NotOnly()
346348
Bijectors.with_logabsdet_jacobian(::NotOnly, y::T) where {T<:Real} = ([y], zero(T))
347349
Bijectors.with_logabsdet_jacobian(::NotOnly, y) = ([y], zero(LogProbType))
350+
struct ExpOnly{L<:Real}
351+
lower::L
352+
end
353+
(e::ExpOnly)(y::AbstractVector{<:Real}) = exp(y[]) + e.lower
354+
function Bijectors.with_logabsdet_jacobian(e::ExpOnly, y::AbstractVector{<:Real})
355+
yi = y[]
356+
x = exp(yi)
357+
return (x + e.lower, yi)
358+
end
359+
Bijectors.inverse(e::ExpOnly) = LogVect(e.lower)
360+
struct LogVect{L<:Real}
361+
lower::L
362+
end
363+
(l::LogVect)(x::Real) = [log(x - l.lower)]
364+
function Bijectors.with_logabsdet_jacobian(l::LogVect, x::Real)
365+
logx = log(x - l.lower)
366+
return ([logx], -logx)
367+
end
368+
Bijectors.inverse(l::LogVect) = ExpOnly(l.lower)
369+
struct TruncateOnly{L<:Real,U<:Real}
370+
lower::L
371+
upper::U
372+
end
373+
function (t::TruncateOnly)(y::AbstractVector{<:Real})
374+
lbounded, ubounded = isfinite(t.lower), isfinite(t.upper)
375+
return if lbounded && ubounded
376+
((t.upper - t.lower) * LogExpFunctions.logistic(y[])) + t.lower
377+
elseif lbounded
378+
exp(y[]) + t.lower
379+
elseif ubounded
380+
t.upper - exp(y[])
381+
else
382+
y[]
383+
end
384+
end
385+
function with_logabsdet_jacobian(t::TruncateOnly, y::AbstractVector{T}) where {T<:Real}
386+
lbounded, ubounded = isfinite(t.lower), isfinite(t.upper)
387+
return if lbounded && ubounded
388+
bma = t.upper - t.lower
389+
yi = y[]
390+
res = (bma * LogExpFunctions.logistic(yi)) + t.lower
391+
# TODO: Bijectors uses this:
392+
# absy = abs(yi)
393+
# return log(bma) - absy - (2 * log1pexp(-absy))
394+
# Check if it's more numerically stable. Don't immediately see a reason why, but I
395+
# assume there's a reason for it.
396+
logjac = log(bma) + yi - (2 * LogExpFunctions.log1pexp(yi))
397+
res, logjac
398+
elseif lbounded
399+
yi = y[]
400+
exp(yi) + t.lower, yi
401+
elseif ubounded
402+
yi = y[]
403+
t.upper - exp(yi), yi
404+
else
405+
y[], zero(T)
406+
end
407+
end
408+
inverse(t::TruncateOnly) = UntruncateVect(t.lower, t.upper)
409+
410+
struct UntruncateVect{L<:Real,U<:Real}
411+
lower::L
412+
upper::U
413+
end
414+
function (u::UntruncateVect)(x::Real)
415+
lbounded, ubounded = isfinite(u.lower), isfinite(u.upper)
416+
return [
417+
if lbounded && ubounded
418+
LogExpFunctions.logit((x - u.lower) / (u.upper - u.lower))
419+
elseif lbounded
420+
log(x - u.lower)
421+
elseif ubounded
422+
log(u.upper - x)
423+
else
424+
x
425+
end,
426+
]
427+
end
428+
function Bijectors.with_logabsdet_jacobian(u::UntruncateVect, x::Real)
429+
lbounded, ubounded = isfinite(u.lower), isfinite(u.upper)
430+
return if lbounded && ubounded
431+
bma = u.upper - u.lower
432+
xma = x - u.lower
433+
xma_over_bma = xma / bma
434+
[LogExpFunctions.logit(xma_over_bma)], -log(xma_over_bma * (u.upper - x))
435+
elseif lbounded
436+
log_xma = log(x - u.lower)
437+
[log_xma], -log_xma
438+
elseif ubounded
439+
log_bmx = log(u.upper - x)
440+
[log_bmx], -log_bmx
441+
else
442+
return zero(x)
443+
end
444+
end
445+
Bijectors.inverse(u::UntruncateVect) = TruncateOnly(u.lower, u.upper)
446+
447+
for dist_type in [
448+
Distributions.Cauchy,
449+
Distributions.Chernoff,
450+
Distributions.Gumbel,
451+
Distributions.JohnsonSU,
452+
Distributions.Laplace,
453+
Distributions.Logistic,
454+
Distributions.NoncentralT,
455+
Distributions.Normal,
456+
Distributions.NormalCanon,
457+
Distributions.NormalInverseGaussian,
458+
Distributions.PGeneralizedGaussian,
459+
Distributions.SkewedExponentialPower,
460+
Distributions.SkewNormal,
461+
Distributions.TDist,
462+
]
463+
@eval begin
464+
from_linked_vec_transform(::$dist_type) = Only()
465+
to_linked_vec_transform(::$dist_type) = NotOnly()
466+
end
467+
end
468+
for dist_type in [
469+
Distributions.BetaPrime,
470+
Distributions.Chi,
471+
Distributions.Chisq,
472+
Distributions.Erlang,
473+
Distributions.Exponential,
474+
Distributions.FDist,
475+
# Wikipedia's definition of the Frechet distribution allows for a location parameter,
476+
# which could cause its minimum to be nonzero. However, Distributionsistributions.jl's `Frechet`
477+
# does not implement this, so we can lump it in here.
478+
Distributions.Frechet,
479+
Distributions.Gamma,
480+
Distributions.InverseGamma,
481+
Distributions.InverseGaussian,
482+
Distributions.Kolmogorov,
483+
Distributions.Lindley,
484+
Distributions.LogNormal,
485+
Distributions.NoncentralChisq,
486+
Distributions.NoncentralF,
487+
Distributions.Rayleigh,
488+
Distributions.Rician,
489+
Distributions.StudentizedRange,
490+
Distributions.Weibull,
491+
]
492+
@eval begin
493+
from_linked_vec_transform(d::$dist_type) = ExpOnly(minimum(d))
494+
to_linked_vec_transform(d::$dist_type) = LogVect(minimum(d))
495+
end
496+
end
497+
function to_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution)
498+
return UntruncateVect(minimum(d), maximum(d))
499+
end
500+
function from_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution)
501+
return TruncateOnly(minimum(d), maximum(d))
502+
end
348503

349504
"""
350505
from_vec_transform(x)
@@ -453,10 +608,6 @@ function from_linked_vec_transform(dist::Distribution)
453608
f_vec = from_vec_transform(inverse(f_invlink), size(dist))
454609
return f_invlink f_vec
455610
end
456-
function from_linked_vec_transform(dist::UnivariateDistribution)
457-
# This is a performance optimisation
458-
return Only() invlink_transform(dist)
459-
end
460611
function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution)
461612
return invlink_transform(dist)
462613
end

0 commit comments

Comments
 (0)