|
| 1 | +using LogExpFunctions: LogExpFunctions |
| 2 | + |
1 | 3 | # singleton for indicating if no default arguments are present |
2 | 4 | struct NoDefault end |
3 | 5 | const NO_DEFAULT = NoDefault() |
@@ -345,6 +347,159 @@ Bijectors.with_logabsdet_jacobian(::Only, x::AbstractVector) = (x[], zero(LogPro |
345 | 347 | Bijectors.inverse(::Only) = NotOnly() |
346 | 348 | Bijectors.with_logabsdet_jacobian(::NotOnly, y::T) where {T<:Real} = ([y], zero(T)) |
347 | 349 | Bijectors.with_logabsdet_jacobian(::NotOnly, y) = ([y], zero(LogProbType)) |
| 350 | +struct ExpOnly{L<:Real} |
| 351 | + lower::L |
| 352 | +end |
| 353 | +(e::ExpOnly)(y::AbstractVector{<:Real}) = exp(y[]) + e.lower |
| 354 | +function Bijectors.with_logabsdet_jacobian(e::ExpOnly, y::AbstractVector{<:Real}) |
| 355 | + yi = y[] |
| 356 | + x = exp(yi) |
| 357 | + return (x + e.lower, yi) |
| 358 | +end |
| 359 | +Bijectors.inverse(e::ExpOnly) = LogVect(e.lower) |
| 360 | +struct LogVect{L<:Real} |
| 361 | + lower::L |
| 362 | +end |
| 363 | +(l::LogVect)(x::Real) = [log(x - l.lower)] |
| 364 | +function Bijectors.with_logabsdet_jacobian(l::LogVect, x::Real) |
| 365 | + logx = log(x - l.lower) |
| 366 | + return ([logx], -logx) |
| 367 | +end |
| 368 | +Bijectors.inverse(l::LogVect) = ExpOnly(l.lower) |
| 369 | +struct TruncateOnly{L<:Real,U<:Real} |
| 370 | + lower::L |
| 371 | + upper::U |
| 372 | +end |
| 373 | +function (t::TruncateOnly)(y::AbstractVector{<:Real}) |
| 374 | + lbounded, ubounded = isfinite(t.lower), isfinite(t.upper) |
| 375 | + return if lbounded && ubounded |
| 376 | + ((t.upper - t.lower) * LogExpFunctions.logistic(y[])) + t.lower |
| 377 | + elseif lbounded |
| 378 | + exp(y[]) + t.lower |
| 379 | + elseif ubounded |
| 380 | + t.upper - exp(y[]) |
| 381 | + else |
| 382 | + y[] |
| 383 | + end |
| 384 | +end |
| 385 | +function with_logabsdet_jacobian(t::TruncateOnly, y::AbstractVector{T}) where {T<:Real} |
| 386 | + lbounded, ubounded = isfinite(t.lower), isfinite(t.upper) |
| 387 | + return if lbounded && ubounded |
| 388 | + bma = t.upper - t.lower |
| 389 | + yi = y[] |
| 390 | + res = (bma * LogExpFunctions.logistic(yi)) + t.lower |
| 391 | + # TODO: Bijectors uses this: |
| 392 | + # absy = abs(yi) |
| 393 | + # return log(bma) - absy - (2 * log1pexp(-absy)) |
| 394 | + # Check if it's more numerically stable. Don't immediately see a reason why, but I |
| 395 | + # assume there's a reason for it. |
| 396 | + logjac = log(bma) + yi - (2 * LogExpFunctions.log1pexp(yi)) |
| 397 | + res, logjac |
| 398 | + elseif lbounded |
| 399 | + yi = y[] |
| 400 | + exp(yi) + t.lower, yi |
| 401 | + elseif ubounded |
| 402 | + yi = y[] |
| 403 | + t.upper - exp(yi), yi |
| 404 | + else |
| 405 | + y[], zero(T) |
| 406 | + end |
| 407 | +end |
| 408 | +inverse(t::TruncateOnly) = UntruncateVect(t.lower, t.upper) |
| 409 | + |
| 410 | +struct UntruncateVect{L<:Real,U<:Real} |
| 411 | + lower::L |
| 412 | + upper::U |
| 413 | +end |
| 414 | +function (u::UntruncateVect)(x::Real) |
| 415 | + lbounded, ubounded = isfinite(u.lower), isfinite(u.upper) |
| 416 | + return [ |
| 417 | + if lbounded && ubounded |
| 418 | + LogExpFunctions.logit((x - u.lower) / (u.upper - u.lower)) |
| 419 | + elseif lbounded |
| 420 | + log(x - u.lower) |
| 421 | + elseif ubounded |
| 422 | + log(u.upper - x) |
| 423 | + else |
| 424 | + x |
| 425 | + end, |
| 426 | + ] |
| 427 | +end |
| 428 | +function Bijectors.with_logabsdet_jacobian(u::UntruncateVect, x::Real) |
| 429 | + lbounded, ubounded = isfinite(u.lower), isfinite(u.upper) |
| 430 | + return if lbounded && ubounded |
| 431 | + bma = u.upper - u.lower |
| 432 | + xma = x - u.lower |
| 433 | + xma_over_bma = xma / bma |
| 434 | + [LogExpFunctions.logit(xma_over_bma)], -log(xma_over_bma * (u.upper - x)) |
| 435 | + elseif lbounded |
| 436 | + log_xma = log(x - u.lower) |
| 437 | + [log_xma], -log_xma |
| 438 | + elseif ubounded |
| 439 | + log_bmx = log(u.upper - x) |
| 440 | + [log_bmx], -log_bmx |
| 441 | + else |
| 442 | + return zero(x) |
| 443 | + end |
| 444 | +end |
| 445 | +Bijectors.inverse(u::UntruncateVect) = TruncateOnly(u.lower, u.upper) |
| 446 | + |
| 447 | +for dist_type in [ |
| 448 | + Distributions.Cauchy, |
| 449 | + Distributions.Chernoff, |
| 450 | + Distributions.Gumbel, |
| 451 | + Distributions.JohnsonSU, |
| 452 | + Distributions.Laplace, |
| 453 | + Distributions.Logistic, |
| 454 | + Distributions.NoncentralT, |
| 455 | + Distributions.Normal, |
| 456 | + Distributions.NormalCanon, |
| 457 | + Distributions.NormalInverseGaussian, |
| 458 | + Distributions.PGeneralizedGaussian, |
| 459 | + Distributions.SkewedExponentialPower, |
| 460 | + Distributions.SkewNormal, |
| 461 | + Distributions.TDist, |
| 462 | +] |
| 463 | + @eval begin |
| 464 | + from_linked_vec_transform(::$dist_type) = Only() |
| 465 | + to_linked_vec_transform(::$dist_type) = NotOnly() |
| 466 | + end |
| 467 | +end |
| 468 | +for dist_type in [ |
| 469 | + Distributions.BetaPrime, |
| 470 | + Distributions.Chi, |
| 471 | + Distributions.Chisq, |
| 472 | + Distributions.Erlang, |
| 473 | + Distributions.Exponential, |
| 474 | + Distributions.FDist, |
| 475 | + # Wikipedia's definition of the Frechet distribution allows for a location parameter, |
| 476 | + # which could cause its minimum to be nonzero. However, Distributionsistributions.jl's `Frechet` |
| 477 | + # does not implement this, so we can lump it in here. |
| 478 | + Distributions.Frechet, |
| 479 | + Distributions.Gamma, |
| 480 | + Distributions.InverseGamma, |
| 481 | + Distributions.InverseGaussian, |
| 482 | + Distributions.Kolmogorov, |
| 483 | + Distributions.Lindley, |
| 484 | + Distributions.LogNormal, |
| 485 | + Distributions.NoncentralChisq, |
| 486 | + Distributions.NoncentralF, |
| 487 | + Distributions.Rayleigh, |
| 488 | + Distributions.Rician, |
| 489 | + Distributions.StudentizedRange, |
| 490 | + Distributions.Weibull, |
| 491 | +] |
| 492 | + @eval begin |
| 493 | + from_linked_vec_transform(d::$dist_type) = ExpOnly(minimum(d)) |
| 494 | + to_linked_vec_transform(d::$dist_type) = LogVect(minimum(d)) |
| 495 | + end |
| 496 | +end |
| 497 | +function to_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution) |
| 498 | + return UntruncateVect(minimum(d), maximum(d)) |
| 499 | +end |
| 500 | +function from_linked_vec_transform(d::Distributions.ContinuousUnivariateDistribution) |
| 501 | + return TruncateOnly(minimum(d), maximum(d)) |
| 502 | +end |
348 | 503 |
|
349 | 504 | """ |
350 | 505 | from_vec_transform(x) |
@@ -453,10 +608,6 @@ function from_linked_vec_transform(dist::Distribution) |
453 | 608 | f_vec = from_vec_transform(inverse(f_invlink), size(dist)) |
454 | 609 | return f_invlink ∘ f_vec |
455 | 610 | end |
456 | | -function from_linked_vec_transform(dist::UnivariateDistribution) |
457 | | - # This is a performance optimisation |
458 | | - return Only() ∘ invlink_transform(dist) |
459 | | -end |
460 | 611 | function from_linked_vec_transform(dist::Distributions.ProductNamedTupleDistribution) |
461 | 612 | return invlink_transform(dist) |
462 | 613 | end |
|
0 commit comments