Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6b46e34
added output_length and output_size to compute output, well, lengths
torfjelde Jun 17, 2023
41d0b06
added tests for size of transformed dist using VecCorrBijector
torfjelde Jun 17, 2023
a76f18a
use already constructed transformation
torfjelde Jun 17, 2023
6afc77e
TransformedDistribution should now also have correct variate form
torfjelde Jun 17, 2023
a4c5683
added proper variateform handling for VecCholeskyBijector too
torfjelde Jun 17, 2023
ea724ee
Apply suggestions from code review
torfjelde Jun 17, 2023
387ef5a
added output_size impl for Reshape too
torfjelde Jun 17, 2023
acb5e8f
bump minor version
torfjelde Jun 18, 2023
3391735
Apply suggestions from code review
torfjelde Jun 18, 2023
b524ebb
Update src/interface.jl
torfjelde Jun 18, 2023
d6dc906
Update src/bijectors/corr.jl
torfjelde Jun 18, 2023
280708b
reverted removal of length as we'll need it now
torfjelde Jun 18, 2023
2069d69
updated Stacked to be compat with changing sizes
torfjelde Jun 18, 2023
f533a79
forgot to commit deletion
torfjelde Jun 18, 2023
56b8834
Apply suggestions from code review
torfjelde Jun 18, 2023
098a9c0
add testing of sizes to `test_bijector`
torfjelde Jun 18, 2023
4e14bb2
some more tests for stacked
torfjelde Jun 18, 2023
def7c6f
Update test/bijectors/stacked.jl
torfjelde Jun 18, 2023
fe36875
added awful generated function to determine output ranges for Stacked
torfjelde Jun 18, 2023
bbfaf19
added slightly more informative comment
torfjelde Jun 18, 2023
bf68124
format
torfjelde Jun 18, 2023
45a9850
more fixes to that damned Stacked
torfjelde Jun 18, 2023
1f0c374
Update test/interface.jl
torfjelde Jun 18, 2023
a917c2b
specialized constructors for Stacked further
torfjelde Jun 18, 2023
cdd951a
fixed bug in output_size for CorrVecBijector
torfjelde Jun 18, 2023
5dbd829
Apply suggestions from code review
torfjelde Jun 18, 2023
04f6990
Apply suggestions from code review
torfjelde Jun 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.12.8"
version = "0.13.0"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for the record: this isn't actually a breaking change, but I want to wait with a new release until we've merged this PR + #271 and #263, i.e. until we have proper support for everything.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the changes to Stacked, this is now indeed a breaking PR.


[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
16 changes: 16 additions & 0 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_inv_corr(y)
end

# Output size of `VecCorrBijector` for a square input of size `sz`.
# An `n × n` input maps to a vector of length `n * (n - 1) ÷ 2`; a non-square `sz` is an error.
function output_size(::VecCorrBijector, sz::Tuple{Int,Int})
    nrows, ncols = sz
    if nrows != ncols
        error("sizes should be equal; received $(sz)")
    end
    return (div(nrows * (nrows - 1), 2),)
end

# Output size of the inverse `VecCorrBijector`: a vector of length `sz[1]` maps to a
# square `(dim, dim)` output, with `dim` recovered by `_triu1_dim_from_length`.
function output_size(::Inverse{VecCorrBijector}, sz::Tuple{Int})
    dim = _triu1_dim_from_length(sz[1])
    return (dim, dim)
end

"""
VecCholeskyBijector <: Bijector

Expand Down Expand Up @@ -317,6 +328,11 @@ function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_inv_chol(y)
end

# `VecCholeskyBijector` reuses the size computation of `VecCorrBijector`.
function output_size(::VecCholeskyBijector, sz::Tuple{Int,Int})
    return output_size(VecCorrBijector(), sz)
end
# The inverse `VecCholeskyBijector` reuses the size computation of the inverse
# `VecCorrBijector`.
function output_size(::Inverse{<:VecCholeskyBijector}, sz::Tuple{Int})
    corr_inverse = inverse(VecCorrBijector())
    return output_size(corr_inverse, sz)
end

"""
function _link_chol_lkj(w)

Expand Down
2 changes: 2 additions & 0 deletions src/bijectors/reshape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,5 @@ end
# The inverse of a `Reshape` is a `Reshape` with the two shapes swapped.
function inverse(b::Reshape)
    return Reshape(b.out_shape, b.in_shape)
end

# Reshape `x` to `b.out_shape`; the log-abs-det-Jacobian term returned alongside it
# is a zero of `eltype(x)`.
function with_logabsdet_jacobian(b::Reshape, x)
    y = reshape(x, b.out_shape)
    return y, zero(eltype(x))
end

function output_size(b::Reshape, in_size)
    # The result is the stored target shape; `in_size` is not consulted.
    return b.out_shape
end
143 changes: 118 additions & 25 deletions src/bijectors/stacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,78 +23,157 @@ b([0.0, 1.0]) == [b1(0.0), 1.0] # => true
"""
struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}} <: Transform
    # Collection of transformations; the `i`-th one is applied to `x[ranges_in[i]]`.
    bs::Bs
    # Index ranges into the input vector, one per entry of `bs`.
    ranges_in::Rs
    # Index ranges into the output vector, one per entry of `bs`.
    ranges_out::Rs
    # Total input length (the constructors pass `sum(length, ranges_in)`).
    length_in::Int
    # Total output length (the constructors pass `sum(length, ranges_out)`).
    length_out::Int
end

# Array-based constructor: derive the output ranges from the bijectors and cache
# the total input and output lengths.
function Stacked(bs::AbstractArray, ranges_in::AbstractArray)
    ranges_out = determine_output_ranges(bs, ranges_in)
    length_in = sum(length, ranges_in)
    length_out = sum(length, ranges_out)
    return Stacked{typeof(bs),typeof(ranges_in)}(
        bs, ranges_in, ranges_out, length_in, length_out
    )
end
# Tuple-based constructor: derive the output ranges from the bijectors and cache
# the total input and output lengths.
function Stacked(bs::Tuple, ranges_in::Tuple)
    ranges_out = determine_output_ranges(bs, ranges_in)
    length_in = sum(length, ranges_in)
    length_out = sum(length, ranges_out)
    return Stacked{typeof(bs),typeof(ranges_in)}(
        bs, ranges_in, ranges_out, length_in, length_out
    )
end
# Mixed tuple/array arguments: the tuple is collected so both containers are arrays.
function Stacked(bs::AbstractArray, ranges::Tuple)
    return Stacked(bs, collect(ranges))
end
function Stacked(bs::Tuple, ranges::AbstractArray)
    return Stacked(collect(bs), ranges)
end

# Without explicit ranges, the `i`-th bijector acts on the single index `i:i`.
function Stacked(bs::Tuple)
    return Stacked(bs, ntuple(i -> i:i, length(bs)))
end
function Stacked(bs::AbstractArray)
    return Stacked(bs, [i:i for i in 1:length(bs)])
end
function Stacked(bs...)
    return Stacked(bs, ntuple(i -> i:i, length(bs)))
end

"""
    determine_output_ranges(bs, ranges)

Return, for each pair `(b, r)` drawn from `bs` and `ranges`, the range of output
indices occupied by `b` when it consumes `length(r)` inputs.

The output ranges are contiguous and start at 1: the `i`-th range has length
`output_length(bs[i], length(ranges[i]))` and begins right after the previous one.
"""
function determine_output_ranges(bs, ranges)
    # Compute all output lengths first and accumulate them with `cumsum`. This avoids
    # a closure that reassigns a captured counter (which Julia has to box) and does
    # not rely on the order in which `map` visits the elements.
    out_lengths = map((b, r) -> output_length(b, length(r)), bs, ranges)
    stops = cumsum(out_lengths)
    return map((len, stop) -> (stop - len + 1):stop, out_lengths, stops)
end

# NOTE: I don't like this but it seems necessary because `Stacked(...)` can occur in hot code paths.
# Tuple inputs are forwarded to the `@generated` implementation.
determine_output_ranges(bs::Tuple, ranges::Tuple) =
    determine_output_ranges_generated(bs, ranges)
# Generates straight-line code computing the output range of every bijector in `bs`.
# For each `i` it emits
#     length_i = output_length(bs[i], length(ranges[i]))
#     r_i = offset .+ (1:length_i)
#     offset += length_i
# and finally returns the tuple `(r_1, ..., r_N)`.
@generated function determine_output_ranges_generated(bs::Tuple, ranges::Tuple)
    n = length(bs.parameters)
    body = Expr[:(offset = 0)]
    range_syms = Symbol[]

    for i in 1:n
        len_i = Symbol("length_$i")
        r_i = Symbol("r_$i")
        append!(
            body,
            (
                :($len_i = output_length(bs[$i], length(ranges[$i]))),
                :($r_i = offset .+ (1:($len_i))),
                :(offset += $len_i),
            ),
        )
        push!(range_syms, r_i)
    end

    result = Expr(:tuple, range_syms...)

    return quote
        $(body...)
        return $result
    end
end

# Avoid mixing tuples and arrays: a tuple of bijectors paired with an array of
# ranges is collected into an array first.
function Stacked(bs::Tuple, ranges::AbstractArray)
    return Stacked(collect(bs), ranges)
end

Functors.@functor Stacked (bs,)

# Print the bijectors together with their input and output ranges.
function Base.show(io::IO, b::Stacked)
    return print(io, "Stacked($(b.bs), $(b.ranges_in), $(b.ranges_out))")
end

# Two `Stacked` are equal when their bijector containers are of the same kind (both
# tuples or both vectors) and their bijectors, input ranges and output ranges agree
# element by element.
function Base.:(==)(b1::Stacked, b2::Stacked)
    bs1, bs2 = b1.bs, b2.bs
    if !(bs1 isa Tuple && bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector)
        return false
    end
    # `.==` throws on containers of different lengths (and silently expands a
    # length-1 container), so compare the lengths up front.
    length(bs1) == length(bs2) || return false
    return all(bs1 .== bs2) &&
           all(b1.ranges_in .== b2.ranges_in) &&
           all(b1.ranges_out .== b2.ranges_out)
end

# A `Stacked` is closed-form exactly when every one of its bijectors is.
function isclosedform(b::Stacked)
    return all(isclosedform, b.bs)
end

# A `Stacked` is invertible exactly when every one of its bijectors is.
function isinvertible(b::Stacked)
    return all(isinvertible, b.bs)
end

# For some reason `inverse.(sb.bs)` was unstable... This works though.
# The inverse swaps the roles of input and output: `ranges_out`/`length_out` become
# the input ranges/length and vice versa.
function inverse(sb::Stacked)
    return Stacked(
        map(inverse, sb.bs), sb.ranges_out, sb.ranges_in, sb.length_out, sb.length_in
    )
end
# map is not type stable for many stacked bijectors as a large tuple
# hence the generated function
# The generated code inverts each `sb.bs[i]` explicitly and swaps the input and
# output ranges/lengths.
@generated function inverse(sb::Stacked{A}) where {A<:Tuple}
    exprs = [:(inverse(sb.bs[$i])) for i in 1:length(A.parameters)]
    return :(Stacked(
        ($(exprs...),), sb.ranges_out, sb.ranges_in, sb.length_out, sb.length_in
    ))
end

# The output of a `Stacked` applied to a vector is a vector of length `length_out`.
output_size(b::Stacked, sz::Tuple{Int}) = (b.length_out,)

# Generates `vcat(bs[1](x[rs[1]]), ..., bs[N](x[rs[N]]))`, i.e. each bijector is
# applied to its own slice of `x` and the results are concatenated.
@generated function _transform_stacked_recursive(
    x, rs::NTuple{N,UnitRange{Int}}, bs...
) where {N}
    exprs = []
    for i in 1:N
        push!(exprs, :(bs[$i](x[rs[$i]])))
    end
    return :(vcat($(exprs...)))
end
# Single-bijector case: the only range must cover all of `x`, and `b` is applied to
# `x` directly. Errors if the range is anything other than `1:length(x)`.
function _transform_stacked_recursive(x, rs::NTuple{1,UnitRange{Int}}, b)
    rs[1] == 1:length(x) || error("range must be 1:length(x)")
    return b(x)
end
# Tuple-based `Stacked`: delegate to `_transform_stacked_recursive`, which applies
# each bijector to its slice of `x` (given by `sb.ranges_in`).
function _transform_stacked(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real})
    y = _transform_stacked_recursive(x, sb.ranges_in, sb.bs...)
    return y
end
# The Stacked{<:AbstractArray} version is not TrackedArray friendly
# Applies `sb.bs[i]` to `x[sb.ranges_in[i]]` for every `i` and concatenates the
# results via `mapvcat`; a single bijector is applied directly.
function _transform_stacked(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real})
    N = length(sb.bs)
    N == 1 && return sb.bs[1](x[sb.ranges_in[1]])

    y = mapvcat(1:N) do i
        sb.bs[i](x[sb.ranges_in[i]])
    end
    return y
end

# Apply `sb` to `x`, checking lengths on both sides: `x` must have `sb.length_in`
# elements and the result must have `sb.length_out` elements, otherwise an error is
# raised. Input and output lengths are allowed to differ, so there is deliberately
# no check that `size(y) == size(x)`.
function transform(sb::Stacked, x::AbstractVector{<:Real})
    if sb.length_in != length(x)
        error("input length mismatch ($(sb.length_in) != $(length(x)))")
    end
    y = _transform_stacked(sb, x)
    if sb.length_out != length(y)
        error("output length mismatch ($(sb.length_out) != $(length(y)))")
    end
    return y
end

function logabsdetjac(b::Stacked, x::AbstractVector{<:Real})
N = length(b.bs)
init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]]))
init = sum(logabsdetjac(b.bs[1], x[b.ranges_in[1]]))

return if N > 1
init + sum(2:N) do i
sum(logabsdetjac(b.bs[i], x[b.ranges[i]]))
sum(logabsdetjac(b.bs[i], x[b.ranges_in[i]]))
end
else
init
Expand All @@ -104,13 +183,13 @@ end
# Tuple-based `Stacked`: the total log-abs-det-Jacobian is the sum of the
# (summed) contributions of each bijector evaluated on its input slice
# `x[b.ranges_in[i]]`.
function logabsdetjac(
    b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector{<:Real}
) where {N}
    init = sum(logabsdetjac(b.bs[1], x[b.ranges_in[1]]))

    return if N == 1
        init
    else
        init + sum(2:N) do i
            sum(logabsdetjac(b.bs[i], x[b.ranges_in[i]]))
        end
    end
end
Expand All @@ -124,21 +203,23 @@ end
# logjac += sum(_logjac)
# return (vcat(y_1, y_2), logjac)
# end
@generated function with_logabsdet_jacobian(
@generated function _with_logabsdet_jacobian(
b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector
) where {N}
expr = Expr(:block)
y_names = []

push!(expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges[1]])))
push!(
expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges_in[1]]))
)
# TODO: drop the `sum` when we have dimensionality
push!(expr.args, :(logjac = sum(_logjac)))
push!(y_names, :y_1)
for i in 2:N
y_name = Symbol("y_$i")
push!(
expr.args,
:(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges[$i]])),
:(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges_in[$i]])),
)

# TODO: drop the `sum` when we have dimensionality
Expand All @@ -151,14 +232,26 @@ end
return expr
end

# Generic fallback: apply each bijector to its input slice `x[sb.ranges_in[i]]`,
# concatenate the outputs with `vcat` and accumulate the summed log-abs-det-Jacobian
# terms. Returns `(ys, logjac)`.
function _with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
    yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges_in[1]])
    logjac = sum(linit)
    ys = mapreduce(vcat, sb.bs[2:end], sb.ranges_in[2:end]; init=yinit) do b, r
        y, l = with_logabsdet_jacobian(b, x[r])
        logjac += sum(l)
        y
    end
    return (ys, logjac)
end

# Length-checked entry point: `x` must have `sb.length_in` elements and the result
# must have `output_length(sb, length(x))` elements, otherwise an error is raised.
# The actual work is done by `_with_logabsdet_jacobian`.
function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector)
    n_in = length(x)
    sb.length_in == n_in ||
        error("input length mismatch ($(sb.length_in) != $(length(x)))")

    y, logjac = _with_logabsdet_jacobian(sb, x)

    expected = output_length(sb, n_in)
    expected == length(y) ||
        error("output length mismatch ($(output_length(sb, length(x))) != $(length(y)))")

    return (y, logjac)
end
14 changes: 14 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@ function logabsdetjac(f::Columnwise, x::AbstractMatrix)
end
with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x))

"""
output_size(f, sz)

Returns the output size of `f` given the input size `sz`.
"""
function output_size(f, sz)
    # Fallback: the input size is passed through unchanged.
    return sz
end

"""
output_length(f, len::Int)

Returns the output length of `f` given the input length `len`.
"""
function output_length(f, len::Int)
    out_size = output_size(f, (len,))
    # `only` throws unless the output size has exactly one entry.
    return only(out_size)
end

######################
# Bijector interface #
######################
Expand Down
27 changes: 13 additions & 14 deletions src/transformed_distribution.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
# Variate form of `d` after applying `b`: an `ArrayLikeVariate` whose dimensionality
# is the number of entries in the output size of `b` for inputs of size `size(d)`.
function variateform(d::Distribution, b)
    out_size = output_size(b, size(d))
    return ArrayLikeVariate{length(out_size)}
end

# Special case: a multivariate distribution pushed through the inverse
# `VecCholeskyBijector` has variate form `CholeskyVariate`.
function variateform(::MultivariateDistribution, ::Inverse{VecCholeskyBijector})
    return CholeskyVariate
end

# Transformed distributions
struct TransformedDistribution{D,B,V} <:
       Distribution{V,Continuous} where {D<:ContinuousDistribution,B}
    # The underlying distribution.
    dist::D
    # The transformation applied to `dist`.
    transform::B

    # The variate form `V` is derived from `d` and `b` via `variateform`, so it
    # need not match the variate form of `d` itself.
    function TransformedDistribution(d::ContinuousDistribution, b)
        return new{typeof(d),typeof(b),variateform(d, b)}(d, b)
    end
end

Expand Down Expand Up @@ -101,8 +100,8 @@ end
##############################

# size
# Size and length refer to the transformed variable: the size of the underlying
# distribution is mapped through `output_size` of the transformation.
Base.length(td::Transformed) = prod(output_size(td.transform, size(td.dist)))
Base.size(td::Transformed) = output_size(td.transform, size(td.dist))

function logpdf(td::UnivariateTransformed, y::Real)
x, logjac = with_logabsdet_jacobian(inverse(td.transform), y)
Expand Down
18 changes: 18 additions & 0 deletions test/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false)

test_ad(x -> sum(bvec(bvecinv(x))), yvec)

# Check that output sizes are computed correctly.
tdist = transformed(dist)
@test length(tdist) == length(yvec)
@test tdist isa MultivariateDistribution

dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(bvec))
@test size(dist_unconstrained) == size(x)
@test dist_unconstrained isa MatrixDistribution
end
end

Expand Down Expand Up @@ -60,6 +69,15 @@ end
# test_bijector is commented out for now,
# as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky)
# test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false)

# Check that output sizes are computed correctly.
tdist = transformed(dist)
@test length(tdist) == length(y)
@test tdist isa MultivariateDistribution

dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(b))
@test size(dist_unconstrained) == size(x)
@test dist_unconstrained isa Distribution{CholeskyVariate,Continuous}
end
end
end
Loading