-
Notifications
You must be signed in to change notification settings - Fork 41
Make SimplexBijector actually bijective #263
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
36a6b41
fa503cb
4abaae9
579c808
3110270
f21d328
7ba8948
7e2927a
ba21df5
3ce84bb
e8ad6cb
6614b15
776e4af
921f818
d934cfa
8f39b0d
79544ba
8efd243
9c16433
78f015e
144df86
a8e6e21
852c826
97af441
1f8a0f1
5d394bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,20 +1,29 @@ | ||
| #################### | ||
| # Simplex bijector # | ||
| #################### | ||
| struct SimplexBijector{T} <: Bijector end | ||
| SimplexBijector() = SimplexBijector{true}() | ||
| struct SimplexBijector <: Bijector end | ||
|
|
||
| output_size(::SimplexBijector, sz::Tuple{Int}) = (first(sz) - 1,) | ||
| output_size(::Inverse{SimplexBijector}, sz::Tuple{Int}) = (first(sz) + 1,) | ||
|
|
||
| output_size(::SimplexBijector, sz::Tuple{Int,Int}) = Base.setindex(sz, first(sz) - 1, 1) | ||
| output_size(::Inverse{SimplexBijector}, sz::Tuple{Int,Int}) = Base.setindex(sz, first(sz) + 1, 1) | ||
torfjelde marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b, x) | ||
|
|
||
| transform(b::SimplexBijector, x) = _simplex_bijector(x, b) | ||
| transform!(b::SimplexBijector, y, x) = _simplex_bijector!(y, x, b) | ||
|
|
||
| function _simplex_bijector(x::AbstractArray, b::SimplexBijector) | ||
| return _simplex_bijector!(similar(x), x, b) | ||
| sz = size(x) | ||
| K = size(x, 1) | ||
| y = similar(x, Base.setindex(sz, K - 1, 1)) | ||
| _simplex_bijector!(y, x, b) | ||
| return y | ||
| end | ||
|
|
||
| # Vector implementation. | ||
| function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where {proj} | ||
| function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector) | ||
| K = length(x) | ||
| @assert K > 1 "x needs to be of length greater than 1" | ||
| T = eltype(x) | ||
|
|
@@ -29,18 +38,11 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where | |
| z = (x[k] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp) | ||
| y[k] = LogExpFunctions.logit(z) + log(T(K - k)) | ||
| end | ||
| @inbounds sum_tmp += x[K - 1] | ||
| @inbounds if proj | ||
| y[K] = zero(T) | ||
| else | ||
| y[K] = one(T) - sum_tmp - x[K] | ||
| end | ||
|
|
||
| return y | ||
| end | ||
|
|
||
| # Matrix implementation. | ||
| function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where {proj} | ||
| function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector) | ||
| K, N = size(X, 1), size(X, 2) | ||
| @assert K > 1 "x needs to be of length greater than 1" | ||
| T = eltype(X) | ||
|
|
@@ -54,12 +56,6 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where | |
| z = (X[k, n] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp) | ||
| Y[k, n] = LogExpFunctions.logit(z) + log(T(K - k)) | ||
| end | ||
| sum_tmp += X[K - 1, n] | ||
| if proj | ||
| Y[K, n] = zero(T) | ||
| else | ||
| Y[K, n] = one(T) - sum_tmp - X[K, n] | ||
| end | ||
| end | ||
|
|
||
| return Y | ||
|
|
@@ -75,10 +71,16 @@ function transform!( | |
| return _simplex_inv_bijector!(x, y, ib.orig) | ||
| end | ||
|
|
||
| _simplex_inv_bijector(y, b) = _simplex_inv_bijector!(similar(y), y, b) | ||
| function _simplex_inv_bijector(y, b) | ||
| sz = size(y) | ||
| K = sz[1] + 1 | ||
| x = similar(y, Base.setindex(sz, K, 1)) | ||
| _simplex_inv_bijector!(x, y, b) | ||
| return x | ||
| end | ||
|
|
||
| function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) where {proj} | ||
| K = length(y) | ||
| function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector) | ||
| K = length(y) + 1 | ||
| @assert K > 1 "x needs to be of length greater than 1" | ||
| T = eltype(y) | ||
| ϵ = _eps(T) | ||
|
|
@@ -91,17 +93,12 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) | |
| x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1) | ||
| end | ||
| @inbounds sum_tmp += x[K - 1] | ||
| @inbounds if proj | ||
| x[K] = _clamp(one(T) - sum_tmp, 0, 1) | ||
| else | ||
| x[K] = _clamp(one(T) - sum_tmp - y[K], 0, 1) | ||
| end | ||
|
|
||
| x[K] = _clamp(one(T) - sum_tmp, 0, 1) | ||
| return x | ||
| end | ||
|
|
||
| function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) where {proj} | ||
| K, N = size(Y, 1), size(Y, 2) | ||
| function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector) | ||
| K, N = size(Y, 1) + 1, size(Y, 2) | ||
| @assert K > 1 "x needs to be of length greater than 1" | ||
| T = eltype(Y) | ||
| ϵ = _eps(T) | ||
|
|
@@ -114,11 +111,7 @@ function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) | |
| X[k, n] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1) | ||
| end | ||
| sum_tmp += X[K - 1, n] | ||
| if proj | ||
| X[K, n] = _clamp(one(T) - sum_tmp, 0, 1) | ||
| else | ||
| X[K, n] = _clamp(one(T) - sum_tmp - Y[K, n], 0, 1) | ||
| end | ||
| X[K, n] = _clamp(one(T) - sum_tmp, 0, 1) | ||
| end | ||
|
|
||
| return X | ||
|
|
@@ -213,13 +206,10 @@ function simplex_logabsdetjac_gradient(x::AbstractMatrix) | |
| return g | ||
| end | ||
|
|
||
| function simplex_link_jacobian( | ||
| x::AbstractVector{T}, ::Val{proj}=Val(true) | ||
| ) where {T<:Real,proj} | ||
| function simplex_link_jacobian(x::AbstractVector{T}) where {T<:Real} | ||
| K = length(x) | ||
| @assert K > 1 "x needs to be of length greater than 1" | ||
| dydxt = similar(x, length(x), length(x)) | ||
| @inbounds dydxt .= 0 | ||
| dydxt = fill!(similar(x, K, K - 1), 0) | ||
| ϵ = _eps(T) | ||
| sum_tmp = zero(T) | ||
|
|
||
|
|
@@ -237,16 +227,10 @@ function simplex_link_jacobian( | |
| ((one(T) + ϵ) - sum_tmp)^2 | ||
| end | ||
| end | ||
| @inbounds sum_tmp += x[K - 1] | ||
| @inbounds if !proj | ||
| @simd for i in 1:K | ||
| dydxt[i, K] = -1 | ||
| end | ||
| end | ||
| return UpperTriangular(dydxt)' | ||
| return dydxt' | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The adjoint operation seems a bit annoying, but I guess the algorithm should be updated in separate PRs if desired. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Likewise, agreed, but this should be fixed separately. |
||
| end | ||
| function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj,T} | ||
| return simplex_link_jacobian(x, Val(proj)) | ||
| function jacobian(b::SimplexBijector, x::AbstractVector{T}) where {T} | ||
| return simplex_link_jacobian(x) | ||
| end | ||
|
|
||
| #= | ||
|
|
@@ -315,13 +299,10 @@ function add_simplex_link_adjoint!( | |
| end | ||
| =# | ||
|
|
||
| function simplex_invlink_jacobian( | ||
| y::AbstractVector{T}, ::Val{proj}=Val(true) | ||
| ) where {T<:Real,proj} | ||
| K = length(y) | ||
| function simplex_invlink_jacobian(y::AbstractVector{T}) where {T<:Real} | ||
| K = length(y) + 1 | ||
| @assert K > 1 "x needs to be of length greater than 1" | ||
| dxdy = similar(y, length(y), length(y)) | ||
| @inbounds dxdy .= 0 | ||
| dxdy = fill!(similar(y, K, K - 1), 0) | ||
|
|
||
| ϵ = _eps(T) | ||
| @inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1))) | ||
|
|
@@ -346,28 +327,20 @@ function simplex_invlink_jacobian( | |
| end | ||
| end | ||
| @inbounds sum_tmp += clamped_x | ||
| @inbounds if proj | ||
| unclamped_x = one(T) - sum_tmp | ||
| clamped_x = _clamp(unclamped_x, 0, 1) | ||
| else | ||
| unclamped_x = one(T) - sum_tmp - y[K] | ||
| clamped_x = _clamp(unclamped_x, 0, 1) | ||
| if unclamped_x == clamped_x | ||
| dxdy[K, K] = -1 | ||
| end | ||
| end | ||
| unclamped_x = one(T) - sum_tmp | ||
| clamped_x = _clamp(unclamped_x, 0, 1) | ||
| @inbounds if unclamped_x == clamped_x | ||
| for i in 1:(K - 1) | ||
| @simd for j in i:(K - 1) | ||
| dxdy[K, i] += -dxdy[j, i] | ||
| end | ||
| end | ||
| end | ||
| return LowerTriangular(dxdy) | ||
| return dxdy | ||
| end | ||
| # jacobian | ||
| function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj,T} | ||
| return simplex_invlink_jacobian(y, Val(proj)) | ||
| function jacobian(ib::Inverse{<:SimplexBijector}, y::AbstractVector{T}) where {T} | ||
| return simplex_invlink_jacobian(y) | ||
| end | ||
|
|
||
| #= | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.