Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/bijectors/truncated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,28 @@ end

with_logabsdet_jacobian(b::TruncatedBijector, x) = transform(b, x), logabsdetjac(b, x)

function truncated_inv_logabsdetjac(y, a, b)
y, a, b = promote(y, a, b)
lowerbounded, upperbounded = isfinite(a), isfinite(b)
if lowerbounded && upperbounded
abs_y = abs(y)
return log(b - a) - abs_y - 2 * LogExpFunctions.log1pexp(-abs_y)
elseif lowerbounded || upperbounded
return y
else
return zero(y)
end
end

function logabsdetjac(ib::Inverse{<:TruncatedBijector}, y)
a, b = ib.orig.lb, ib.orig.ub
return sum(truncated_inv_logabsdetjac.(y, a, b))
end

function with_logabsdet_jacobian(ib::Inverse{<:TruncatedBijector}, y)
return transform(ib, y), logabsdetjac(ib, y)
end

# It's only monotonically decreasing if it's only upper-bounded.
# In the multivariate case, we can only say something reasonable if entries are monotonic.
function is_monotonically_increasing(b::TruncatedBijector)
Expand Down
2 changes: 1 addition & 1 deletion test/bijectors/ordered.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ end

@testset "correctness" begin
num_samples = 10_000
num_adapts = 1_000
num_adapts = 5_000
@testset "k = $k" for k in [2, 3, 5]
@testset "$(typeof(dist))" for dist in [
# Unconstrained
Expand Down
15 changes: 15 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ contains(predicate::Function, b::Stacked) = any(contains.(predicate, b.bs))
logabsdetjac(inverse(b), y) atol = 1e-6
end
end

@testset "logabsdetjac numerical stability: Bijectors.jl#325" begin
d = Uniform(-1, 1)
b = bijector(d)
y = 80
# x needs higher precision to be calculated correctly, otherwise
# logpdf_with_trans returns -Inf
d_big = Uniform(big(-1.0), big(1.0))
b_big = bijector(d_big)
x_big = inverse(b_big)(big(y))
@test logpdf(d_big, x_big) + logabsdetjacinv(b, y) ≈
logpdf_with_trans(d_big, x_big, true) atol = 1e-14
@test logpdf(d_big, x_big) - logabsdetjac(b, x_big) ≈
logpdf_with_trans(d_big, x_big, true) atol = 1e-14
end
end

@testset "Truncated" begin
Expand Down