From 6b46e344ccd9fcf66e7edcc9c5f03162884738ec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 21:49:30 +0100 Subject: [PATCH 01/27] added output_length and output_size to compute output, well, lengths and sizes for transformations --- src/bijectors/corr.jl | 11 +++++++++++ src/interface.jl | 20 ++++++++++++++++++++ src/transformed_distribution.jl | 4 ++-- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 9367b0cd..c25d0ff3 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -232,6 +232,17 @@ function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) return _logabsdetjac_inv_corr(y) end +function output_size(::VecCorrBijector, sz::NTuple{2}) + @assert sz[1] == sz[2] + n = sz[1] + return (n * (n - 1)) ÷ 2 +end + +function output_size(::Inverse{VecCorrBijector}, sz::NTuple{1}) + n = _triu1_dim_from_length(first(sz)) + return (n, n) +end + """ VecCholeskyBijector <: Bijector diff --git a/src/interface.jl b/src/interface.jl index 91c9a961..68886ce0 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -35,6 +35,26 @@ function logabsdetjac(f::Columnwise, x::AbstractMatrix) end with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x)) +""" + output_size(f, sz) + +Returns the output size of `f` given the input size `sz`. +""" +output_size(f, sz) = sz + +""" + output_length(f, len::Int) + output_length(f, sz::Tuple) + +Returns the output length of `f` given the input length `len` or size `sz`. 
+""" +output_length(f, len::Int) = len +function output_length(f, len::Tuple) + sz = output_size(f, len) + @assert length(sz) == 1 + return first(sz) +end + ###################### # Bijector interface # ###################### diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index eccbb64c..b7375899 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -101,8 +101,8 @@ end ############################## # size -Base.length(td::Transformed) = length(td.dist) -Base.size(td::Transformed) = size(td.dist) +Base.length(td::Transformed) = output_length(td.transform, size(td.dist)) +Base.size(td::Transformed) = output_size(td.transform, size(td.dist)) function logpdf(td::UnivariateTransformed, y::Real) x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) From 41d0b063443964028746f005f47f94e9bc1301a0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 21:54:05 +0100 Subject: [PATCH 02/27] added tests for size of transformed dist using VecCorrBijector --- test/bijectors/corr.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 71bfd8d7..32ef8e7e 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -33,6 +33,13 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false) test_ad(x -> sum(bvec(bvecinv(x))), yvec) + + # Check that output sizes are computed correctly. 
+ dist = transformed(dist) + @test length(dist) == length(yvec) + + dist_unconstrained = transformed(MvNormal(zeros(length(dist)), I), inverse(dist.transform)) + @test size(dist_unconstrained) == size(x) end end From a76f18a8fc06ff6e4b5e3cdc09a92a9e4e37a63b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 21:55:10 +0100 Subject: [PATCH 03/27] use already constructed transformation --- test/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 32ef8e7e..9e29bc81 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -38,7 +38,7 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector dist = transformed(dist) @test length(dist) == length(yvec) - dist_unconstrained = transformed(MvNormal(zeros(length(dist)), I), inverse(dist.transform)) + dist_unconstrained = transformed(MvNormal(zeros(length(dist)), I), inverse(bvec)) @test size(dist_unconstrained) == size(x) end end From 6afc77e8339fe5b5e5153ed14e4026abe3766b48 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 22:15:55 +0100 Subject: [PATCH 04/27] TransformedDistribution should now also have correct variate form --- src/transformed_distribution.jl | 22 +++++++++------------- test/bijectors/corr.jl | 2 ++ 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index b7375899..9622698f 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -1,20 +1,16 @@ +function variateform(d::Distribution{<:ArrayLikeVariate}, b) + sz_in = size(d) + sz_out = output_size(b, sz_in) + return ArrayLikeVariate{length(sz_out)} +end + # Transformed distributions -struct TransformedDistribution{D,B,V} <: - Distribution{V,Continuous} where {D<:Distribution{V,Continuous},B} +struct TransformedDistribution{D,B,V} <: Distribution{V,Continuous} where {D<:ContinuousDistribution,B} dist::D 
transform::B - function TransformedDistribution(d::UnivariateDistribution, b) - return new{typeof(d),typeof(b),Univariate}(d, b) - end - function TransformedDistribution(d::MultivariateDistribution, b) - return new{typeof(d),typeof(b),Multivariate}(d, b) - end - function TransformedDistribution(d::MatrixDistribution, b) - return new{typeof(d),typeof(b),Matrixvariate}(d, b) - end - function TransformedDistribution(d::Distribution{CholeskyVariate}, b) - return new{typeof(d),typeof(b),CholeskyVariate}(d, b) + function TransformedDistribution(d::ContinuousDistribution, b) + return new{typeof(d),typeof(b),variateform(d,b)}(d, b) end end diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 9e29bc81..0fb29412 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -37,9 +37,11 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector # Check that output sizes are computed correctly. dist = transformed(dist) @test length(dist) == length(yvec) + @test dist isa MultivariateDistribution dist_unconstrained = transformed(MvNormal(zeros(length(dist)), I), inverse(bvec)) @test size(dist_unconstrained) == size(x) + @test dist_unconstrained isa MatrixDistribution end end From a4c56839fc2159ea91b5a3dd9868b5ab0f594dcf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 22:38:26 +0100 Subject: [PATCH 05/27] added proper variateform handling for VecCholeskyBijector too --- src/bijectors/corr.jl | 3 +++ src/transformed_distribution.jl | 4 +++- test/bijectors/corr.jl | 19 +++++++++++++++---- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index c25d0ff3..89716ea8 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -328,6 +328,9 @@ function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) return _logabsdetjac_inv_chol(y) end +output_size(::VecCholeskyBijector, sz::NTuple{2}) = output_size(VecCorrBijector(), sz) 
+output_size(::Inverse{<:VecCholeskyBijector}, sz::NTuple{1}) = output_size(inverse(VecCorrBijector()), sz) + """ function _link_chol_lkj(w) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 9622698f..258554f7 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -1,9 +1,11 @@ -function variateform(d::Distribution{<:ArrayLikeVariate}, b) +function variateform(d::Distribution, b) sz_in = size(d) sz_out = output_size(b, sz_in) return ArrayLikeVariate{length(sz_out)} end +variateform(::MultivariateDistribution, ::Inverse{VecCholeskyBijector}) = CholeskyVariate + # Transformed distributions struct TransformedDistribution{D,B,V} <: Distribution{V,Continuous} where {D<:ContinuousDistribution,B} dist::D diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index 0fb29412..d6a4e784 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -35,11 +35,11 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector test_ad(x -> sum(bvec(bvecinv(x))), yvec) # Check that output sizes are computed correctly. - dist = transformed(dist) - @test length(dist) == length(yvec) - @test dist isa MultivariateDistribution + tdist = transformed(dist) + @test length(tdist) == length(yvec) + @test tdist isa MultivariateDistribution - dist_unconstrained = transformed(MvNormal(zeros(length(dist)), I), inverse(bvec)) + dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(bvec)) @test size(dist_unconstrained) == size(x) @test dist_unconstrained isa MatrixDistribution end @@ -69,6 +69,17 @@ end # test_bijector is commented out for now, # as isapprox is not defined for ::Cholesky types (the domain of LKJCholesky) # test_bijector(b, x; test_not_identity=d != 1, changes_of_variables_test=false) + + # Check that output sizes are computed correctly. 
+ tdist = transformed(dist) + @test length(tdist) == length(y) + @test tdist isa MultivariateDistribution + + dist_unconstrained = transformed( + MvNormal(zeros(length(tdist)), I), inverse(b) + ) + @test size(dist_unconstrained) == size(x) + @test dist_unconstrained isa Distribution{CholeskyVariate,Continuous} end end end From ea724eeced3646483752ae257dfd7cbbd87c5c4f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 22:41:48 +0100 Subject: [PATCH 06/27] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/bijectors/corr.jl | 4 +++- src/transformed_distribution.jl | 5 +++-- test/bijectors/corr.jl | 4 +--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 89716ea8..c9284332 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -329,7 +329,9 @@ function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) end output_size(::VecCholeskyBijector, sz::NTuple{2}) = output_size(VecCorrBijector(), sz) -output_size(::Inverse{<:VecCholeskyBijector}, sz::NTuple{1}) = output_size(inverse(VecCorrBijector()), sz) +function output_size(::Inverse{<:VecCholeskyBijector}, sz::NTuple{1}) + return output_size(inverse(VecCorrBijector()), sz) +end """ function _link_chol_lkj(w) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 258554f7..fd147267 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -7,12 +7,13 @@ end variateform(::MultivariateDistribution, ::Inverse{VecCholeskyBijector}) = CholeskyVariate # Transformed distributions -struct TransformedDistribution{D,B,V} <: Distribution{V,Continuous} where {D<:ContinuousDistribution,B} +struct TransformedDistribution{D,B,V} <: + Distribution{V,Continuous} where {D<:ContinuousDistribution,B} dist::D transform::B function TransformedDistribution(d::ContinuousDistribution, b) - 
return new{typeof(d),typeof(b),variateform(d,b)}(d, b) + return new{typeof(d),typeof(b),variateform(d, b)}(d, b) end end diff --git a/test/bijectors/corr.jl b/test/bijectors/corr.jl index d6a4e784..8a423bc3 100644 --- a/test/bijectors/corr.jl +++ b/test/bijectors/corr.jl @@ -75,9 +75,7 @@ end @test length(tdist) == length(y) @test tdist isa MultivariateDistribution - dist_unconstrained = transformed( - MvNormal(zeros(length(tdist)), I), inverse(b) - ) + dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(b)) @test size(dist_unconstrained) == size(x) @test dist_unconstrained isa Distribution{CholeskyVariate,Continuous} end From 387ef5a3c028b496ef3c71731dbac6506eb669bf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 00:15:44 +0100 Subject: [PATCH 07/27] added output_size impl for Reshape too --- src/bijectors/reshape.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/bijectors/reshape.jl b/src/bijectors/reshape.jl index 8a8bd1e4..4f8665cd 100644 --- a/src/bijectors/reshape.jl +++ b/src/bijectors/reshape.jl @@ -25,3 +25,5 @@ end inverse(b::Reshape) = Reshape(b.out_shape, b.in_shape) with_logabsdet_jacobian(b::Reshape, x) = reshape(x, b.out_shape), zero(eltype(x)) + +output_size(b::Reshape, in_size) = b.out_shape From acb5e8f000dc03cb97ae2d8fe28c68ef573b5ae0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 02:42:11 +0100 Subject: [PATCH 08/27] bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 978d97df..dcf1ce79 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.12.8" +version = "0.13.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From 339173525f3f345697bee695b8a489b511ee317a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 10:40:44 +0100 Subject: [PATCH 09/27] Apply suggestions from 
code review Co-authored-by: David Widmann --- src/bijectors/corr.jl | 8 ++++---- src/interface.jl | 4 ++-- src/transformed_distribution.jl | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index c9284332..80419e57 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -232,13 +232,13 @@ function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) return _logabsdetjac_inv_corr(y) end -function output_size(::VecCorrBijector, sz::NTuple{2}) +function output_size(::VecCorrBijector, sz::Tuple{Int,Int}) @assert sz[1] == sz[2] n = sz[1] return (n * (n - 1)) ÷ 2 end -function output_size(::Inverse{VecCorrBijector}, sz::NTuple{1}) +function output_size(::Inverse{VecCorrBijector}, sz::Tuple{Int}) n = _triu1_dim_from_length(first(sz)) return (n, n) end @@ -328,8 +328,8 @@ function logabsdetjac(::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real}) return _logabsdetjac_inv_chol(y) end -output_size(::VecCholeskyBijector, sz::NTuple{2}) = output_size(VecCorrBijector(), sz) -function output_size(::Inverse{<:VecCholeskyBijector}, sz::NTuple{1}) +output_size(::VecCholeskyBijector, sz::Tuple{Int,Int}) = output_size(VecCorrBijector(), sz) +function output_size(::Inverse{<:VecCholeskyBijector}, sz::Tuple{Int}) return output_size(inverse(VecCorrBijector()), sz) end diff --git a/src/interface.jl b/src/interface.jl index 68886ce0..670db19b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -48,8 +48,8 @@ output_size(f, sz) = sz Returns the output length of `f` given the input length `len` or size `sz`. 
""" -output_length(f, len::Int) = len -function output_length(f, len::Tuple) +output_length(f, len::Int) = only(output_size(f, (len,))) +output_length(f, len::Tuple{Int,Vararg{Int}}) = only(output_size(f, len)) sz = output_size(f, len) @assert length(sz) == 1 return first(sz) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index fd147267..bc9f2e90 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -100,7 +100,7 @@ end ############################## # size -Base.length(td::Transformed) = output_length(td.transform, size(td.dist)) +Base.length(td::Transformed) = only(output_size(td.transform, size(td.dist))) Base.size(td::Transformed) = output_size(td.transform, size(td.dist)) function logpdf(td::UnivariateTransformed, y::Real) From b524ebb57d4eda573e8f2490d2203fb6a4320fdd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 10:41:42 +0100 Subject: [PATCH 10/27] Update src/interface.jl --- src/interface.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 670db19b..11fc7f55 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -42,19 +42,6 @@ Returns the output size of `f` given the input size `sz`. """ output_size(f, sz) = sz -""" - output_length(f, len::Int) - output_length(f, sz::Tuple) - -Returns the output length of `f` given the input length `len` or size `sz`. 
-""" -output_length(f, len::Int) = only(output_size(f, (len,))) -output_length(f, len::Tuple{Int,Vararg{Int}}) = only(output_size(f, len)) - sz = output_size(f, len) - @assert length(sz) == 1 - return first(sz) -end - ###################### # Bijector interface # ###################### From d6dc906960db0ca14a7e91ae9b8b372dd73f3faa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 10:44:38 +0100 Subject: [PATCH 11/27] Update src/bijectors/corr.jl --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 80419e57..179ed4e1 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -233,7 +233,7 @@ function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real}) end function output_size(::VecCorrBijector, sz::Tuple{Int,Int}) - @assert sz[1] == sz[2] + sz[1] == sz[2] || error("sizes should be equal; received $(sz)") n = sz[1] return (n * (n - 1)) ÷ 2 end From 280708b684bc96bbcdb2e641c936ca87bec0bc76 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 11:22:19 +0100 Subject: [PATCH 12/27] reverted removal of length as we'll need it now --- src/interface.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/interface.jl b/src/interface.jl index 11fc7f55..670db19b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -42,6 +42,19 @@ Returns the output size of `f` given the input size `sz`. """ output_size(f, sz) = sz +""" + output_length(f, len::Int) + output_length(f, sz::Tuple) + +Returns the output length of `f` given the input length `len` or size `sz`. 
+""" +output_length(f, len::Int) = only(output_size(f, (len,))) +output_length(f, len::Tuple{Int,Vararg{Int}}) = only(output_size(f, len)) + sz = output_size(f, len) + @assert length(sz) == 1 + return first(sz) +end + ###################### # Bijector interface # ###################### From 2069d6900dd8cd21bb5348cf9292d59cdd271fc9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 11:22:34 +0100 Subject: [PATCH 13/27] updated Stacked to be compat with changing sizes --- src/bijectors/stacked.jl | 52 ++++++++++++++++++++++++++-------------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 4f0596cb..7a27f14b 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -23,25 +23,38 @@ b([0.0, 1.0]) == [b1(0.0), 1.0] # => true """ struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}} <: Transform bs::Bs - ranges::Rs + ranges_in::Rs + ranges_out::Rs end + +Stacked(bs, ranges) = Stacked(bs, ranges, determine_output_ranges(bs, ranges)) Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs))) Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)]) Stacked(bs...) = Stacked(bs, ntuple(i -> i:i, length(bs))) +function determine_output_ranges(bs, ranges) + offset = 0 + return map(bs, ranges) do b, r + out_length = output_length(b, length(r)) + r = offset .+ (1:out_length) + offset += out_length + return r + end +end + # Avoid mixing tuples and arrays. 
Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) Functors.@functor Stacked (bs,) -Base.show(io::IO, b::Stacked) = print(io, "Stacked($(b.bs), $(b.ranges))") +Base.show(io::IO, b::Stacked) = print(io, "Stacked($(b.bs), $(b.ranges_in), $(b.ranges_out))") function Base.:(==)(b1::Stacked, b2::Stacked) bs1, bs2 = b1.bs, b2.bs if !(bs1 isa Tuple && bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector) return false end - return all(bs1 .== bs2) && all(b1.ranges .== b2.ranges) + return all(bs1 .== bs2) && all(b1.ranges_in .== b2.ranges_in) && all(b1.ranges_out .== b2.ranges_out) end isclosedform(b::Stacked) = all(isclosedform, b.bs) @@ -49,7 +62,7 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs) isinvertible(b::Stacked) = all(isinvertible, b.bs) # For some reason `inverse.(sb.bs)` was unstable... This works though. -inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) +inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges_out, sb.ranges_in) # map is not type stable for many stacked bijectors as a large tuple # hence the generated function @generated function inverse(sb::Stacked{A}) where {A<:Tuple} @@ -57,7 +70,12 @@ inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges) for i in 1:length(A.parameters) push!(exprs, :(inverse(sb.bs[$i]))) end - return :(Stacked(($(exprs...),), sb.ranges)) + return :(Stacked(($(exprs...),), sb.ranges_out, sb.ranges_in)) +end + +function output_size(b::Stacked, sz::Tuple{Int}) + sz_out = sum(length, b.ranges_out) + return (sz_out,) end @generated function _transform(x, rs::NTuple{N,UnitRange{Int}}, bs...) where {N} @@ -72,29 +90,27 @@ function _transform(x, rs::NTuple{1,UnitRange{Int}}, b) return b(x) end function transform(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real}) - y = _transform(x, sb.ranges, sb.bs...) - @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" + y = _transform(x, sb.ranges_in, sb.bs...) 
return y end # The Stacked{<:AbstractArray} version is not TrackedArray friendly function transform(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) N = length(sb.bs) - N == 1 && return sb.bs[1](x[sb.ranges[1]]) + N == 1 && return sb.bs[1](x[sb.ranges_in[1]]) y = mapvcat(1:N) do i - sb.bs[i](x[sb.ranges[i]]) + sb.bs[i](x[sb.ranges_in[i]]) end - @assert size(y) == size(x) "x is size $(size(x)) but y is $(size(y))" return y end function logabsdetjac(b::Stacked, x::AbstractVector{<:Real}) N = length(b.bs) - init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) + init = sum(logabsdetjac(b.bs[1], x[b.ranges_in[1]])) return if N > 1 init + sum(2:N) do i - sum(logabsdetjac(b.bs[i], x[b.ranges[i]])) + sum(logabsdetjac(b.bs[i], x[b.ranges_in[i]])) end else init @@ -104,13 +120,13 @@ end function logabsdetjac( b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector{<:Real} ) where {N} - init = sum(logabsdetjac(b.bs[1], x[b.ranges[1]])) + init = sum(logabsdetjac(b.bs[1], x[b.ranges_in[1]])) return if N == 1 init else init + sum(2:N) do i - sum(logabsdetjac(b.bs[i], x[b.ranges[i]])) + sum(logabsdetjac(b.bs[i], x[b.ranges_in[i]])) end end end @@ -130,7 +146,7 @@ end expr = Expr(:block) y_names = [] - push!(expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges[1]]))) + push!(expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges_in[1]]))) # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac = sum(_logjac))) push!(y_names, :y_1) @@ -138,7 +154,7 @@ end y_name = Symbol("y_$i") push!( expr.args, - :(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges[$i]])), + :(($y_name, _logjac) = with_logabsdet_jacobian(b.bs[$i], x[b.ranges_in[$i]])), ) # TODO: drop the `sum` when we have dimensionality @@ -153,9 +169,9 @@ end function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) N = length(sb.bs) - yinit, linit = with_logabsdet_jacobian(sb.bs[1], x[sb.ranges[1]]) + yinit, linit = 
with_logabsdet_jacobian(sb.bs[1], x[sb.ranges_in[1]]) logjac = sum(linit) - ys = mapreduce(vcat, sb.bs[2:end], sb.ranges[2:end]; init=yinit) do b, r + ys = mapreduce(vcat, sb.bs[2:end], sb.ranges_in[2:end]; init=yinit) do b, r y, l = with_logabsdet_jacobian(b, x[r]) logjac += sum(l) y From f533a79e6d6e241972172797049bae2ac41940ab Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 11:23:07 +0100 Subject: [PATCH 14/27] forgot to commit deletion --- src/interface.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 670db19b..f8c94a31 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -49,11 +49,6 @@ output_size(f, sz) = sz Returns the output length of `f` given the input length `len` or size `sz`. """ output_length(f, len::Int) = only(output_size(f, (len,))) -output_length(f, len::Tuple{Int,Vararg{Int}}) = only(output_size(f, len)) - sz = output_size(f, len) - @assert length(sz) == 1 - return first(sz) -end ###################### # Bijector interface # ###################### From 56b88341a42cf5b35527a410fe86c9fe58d4c0b8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 11:27:44 +0100 Subject: [PATCH 15/27] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/bijectors/stacked.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 7a27f14b..cd1ee499 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -47,14 +47,18 @@ Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) Functors.@functor Stacked (bs,) -Base.show(io::IO, b::Stacked) = print(io, "Stacked($(b.bs), $(b.ranges_in), $(b.ranges_out))") +function Base.show(io::IO, b::Stacked) + return print(io, "Stacked($(b.bs), $(b.ranges_in), $(b.ranges_out))") +end function Base.:(==)(b1::Stacked, b2::Stacked) bs1, bs2 = b1.bs, b2.bs if !(bs1 isa Tuple 
&& bs2 isa Tuple || bs1 isa Vector && bs2 isa Vector) return false end - return all(bs1 .== bs2) && all(b1.ranges_in .== b2.ranges_in) && all(b1.ranges_out .== b2.ranges_out) + return all(bs1 .== bs2) && + all(b1.ranges_in .== b2.ranges_in) && + all(b1.ranges_out .== b2.ranges_out) end isclosedform(b::Stacked) = all(isclosedform, b.bs) @@ -146,7 +150,9 @@ end expr = Expr(:block) y_names = [] - push!(expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges_in[1]]))) + push!( + expr.args, :((y_1, _logjac) = with_logabsdet_jacobian(b.bs[1], x[b.ranges_in[1]])) + ) # TODO: drop the `sum` when we have dimensionality push!(expr.args, :(logjac = sum(_logjac))) push!(y_names, :y_1) From 098a9c0f8ce272c2e5e833a3b3468dd9c0d86594 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 11:31:45 +0100 Subject: [PATCH 16/27] add testing of sizes to `test_bijector` --- test/bijectors/utils.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/bijectors/utils.jl b/test/bijectors/utils.jl index b9ab1242..8c31ccec 100644 --- a/test/bijectors/utils.jl +++ b/test/bijectors/utils.jl @@ -14,6 +14,7 @@ function test_bijector( test_types=false, changes_of_variables_test=true, inverse_functions_test=true, + test_sizes=true, compare=isapprox, kwargs..., ) @@ -31,6 +32,11 @@ function test_bijector( @inferred(with_logabsdet_jacobian(inverse(b), y_test)) end + if test_sizes + @test Bijectors.output_size(b, size(x)) == size(y_test) + @test Bijectors.output_size(ib, size(y_test)) == size(x) + end + # ChangesOfVariables.jl # For non-bijective transformations, these tests always fail since determinant of # the Jacobian is zero. Hence we allow the caller to disable them if necessary. 
From 4e14bb2b77cacb4ad1dbb431d5db83281d624297 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 11:32:07 +0100 Subject: [PATCH 17/27] some more tests for stacked --- test/bijectors/stacked.jl | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 test/bijectors/stacked.jl diff --git a/test/bijectors/stacked.jl b/test/bijectors/stacked.jl new file mode 100644 index 00000000..4ae3cef2 --- /dev/null +++ b/test/bijectors/stacked.jl @@ -0,0 +1,26 @@ +struct ProjectionBijector <: Bijectors.Bijector end + +Bijectors.output_size(::ProjectionBijector, sz::Tuple{Int}) = (sz[1] - 1,) +Bijectors.output_size(::Inverse{ProjectionBijector}, sz::Int) = (sz[1] + 1,) + +Bijectors.with_logabsdet_jacobian(::ProjectionBijector, x::AbstractVector) = x[1:(end - 1)], 0 +Bijectors.with_logabsdet_jacobian(::Inverse{ProjectionBijector}, x::AbstractVector) = vcat(x, 0), 0 + +@testset "Stacked with differing input and output size" begin + b = Stacked((elementwise(exp), ProjectionBijector()), (1:1, 2:3)) + binv = inverse(b) + x = [1.0, 2.0, 3.0] + y = b(x) + x_ = binv(y) + + # Are the values of correct size? + @test size(y) == (2,) + @test size(x_) == (3,) + # Can we determine the sizes correctly? + @test Bijectors.output_size(b, size(x)) == (2,) + @test Bijectors.output_size(binv, size(y)) == (3,) + + # Are values correct? 
+ @test y == [exp(1.0), 2.0] + @test binv(y) == [1.0, 2.0, 0.0] +end From def7c6f4aca7a45d3d4c212ea43f50436b93e6a6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 11:36:07 +0100 Subject: [PATCH 18/27] Update test/bijectors/stacked.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/bijectors/stacked.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/bijectors/stacked.jl b/test/bijectors/stacked.jl index 4ae3cef2..185c99ec 100644 --- a/test/bijectors/stacked.jl +++ b/test/bijectors/stacked.jl @@ -3,8 +3,12 @@ struct ProjectionBijector <: Bijectors.Bijector end Bijectors.output_size(::ProjectionBijector, sz::Tuple{Int}) = (sz[1] - 1,) Bijectors.output_size(::Inverse{ProjectionBijector}, sz::Int) = (sz[1] + 1,) -Bijectors.with_logabsdet_jacobian(::ProjectionBijector, x::AbstractVector) = x[1:(end - 1)], 0 -Bijectors.with_logabsdet_jacobian(::Inverse{ProjectionBijector}, x::AbstractVector) = vcat(x, 0), 0 +function Bijectors.with_logabsdet_jacobian(::ProjectionBijector, x::AbstractVector) + return x[1:(end - 1)], 0 +end +function Bijectors.with_logabsdet_jacobian(::Inverse{ProjectionBijector}, x::AbstractVector) + return vcat(x, 0), 0 +end @testset "Stacked with differing input and output size" begin b = Stacked((elementwise(exp), ProjectionBijector()), (1:1, 2:3)) From fe36875efed549858c0870dc99e69a68adeee1c6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 12:30:17 +0100 Subject: [PATCH 19/27] added awful generated function to determine output ranges for Stacked with tuple because recursive implementation fail --- src/bijectors/stacked.jl | 33 ++++++++++++++++++++++++++++++++- test/bijectors/stacked.jl | 36 ++++++++++++++++++++++-------------- 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index cd1ee499..43ce7b74 100644 --- a/src/bijectors/stacked.jl +++ 
b/src/bijectors/stacked.jl @@ -27,7 +27,12 @@ struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}} <: Transform ranges_out::Rs end -Stacked(bs, ranges) = Stacked(bs, ranges, determine_output_ranges(bs, ranges)) +function Stacked(bs, ranges_in) + ranges_out = determine_output_ranges(bs, ranges_in) + return Stacked{typeof(bs), typeof(ranges_in)}(bs, ranges_in, ranges_out) +end +Stacked(bs::AbstractVector, ranges::Tuple) = Stacked(bs, [ranges...]) +Stacked(bs::Tuple, ranges::AbstractVector) = Stacked([bs...], ranges) Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs))) Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)]) Stacked(bs...) = Stacked(bs, ntuple(i -> i:i, length(bs))) @@ -42,6 +47,32 @@ function determine_output_ranges(bs, ranges) end end +# NOTE: I don't like this. +determine_output_ranges(bs::Tuple, ranges::Tuple) = determine_output_ranges_generated(bs, ranges) +@generated function determine_output_ranges_generated(bs::Tuple, ranges::Tuple) + N = length(bs.parameters) + exprs = [] + push!(exprs, :(offset = 0)) + + rsyms = [] + for i = 1:N + rsym = Symbol("r_$i") + lengthsym = Symbol("length_$i") + push!(exprs, :($lengthsym = output_length(bs[$i], length(ranges[$i])))) + push!(exprs, :($rsym = offset .+ (1:$lengthsym))) + push!(exprs, :(offset += $lengthsym)) + + push!(rsyms, rsym) + end + + acc_expr = Expr(:tuple, rsyms...) + + return quote + $(exprs...) + return $acc_expr + end +end + # Avoid mixing tuples and arrays. 
Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) diff --git a/test/bijectors/stacked.jl b/test/bijectors/stacked.jl index 185c99ec..866c7e41 100644 --- a/test/bijectors/stacked.jl +++ b/test/bijectors/stacked.jl @@ -11,20 +11,28 @@ function Bijectors.with_logabsdet_jacobian(::Inverse{ProjectionBijector}, x::Abs end @testset "Stacked with differing input and output size" begin - b = Stacked((elementwise(exp), ProjectionBijector()), (1:1, 2:3)) - binv = inverse(b) - x = [1.0, 2.0, 3.0] - y = b(x) - x_ = binv(y) + bs = [ + Stacked((elementwise(exp), ProjectionBijector()), (1:1, 2:3)), + Stacked([elementwise(exp), ProjectionBijector()], [1:1, 2:3]), + Stacked([elementwise(exp), ProjectionBijector()], (1:1, 2:3)), + Stacked((elementwise(exp), ProjectionBijector()), [1:1, 2:3]) + ] + @testset "$b" for b in bs + binv = inverse(b) + x = [1.0, 2.0, 3.0] + y = b(x) + x_ = binv(y) - # Are the values of correct size? - @test size(y) == (2,) - @test size(x_) == (3,) - # Can we determine the sizes correctly? - @test Bijectors.output_size(b, size(x)) == (2,) - @test Bijectors.output_size(binv, size(y)) == (3,) + # Are the values of correct size? + @test size(y) == (2,) + @test size(x_) == (3,) + # Can we determine the sizes correctly? + @test Bijectors.output_size(b, size(x)) == (2,) + @test Bijectors.output_size(binv, size(y)) == (3,) - # Are values correct? - @test y == [exp(1.0), 2.0] - @test binv(y) == [1.0, 2.0, 0.0] + # Are values correct? 
+ @test y == [exp(1.0), 2.0] + @test binv(y) == [1.0, 2.0, 0.0] + end end + From bbfaf1915ccef2b517ff3b998377634a54da35bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 12:31:27 +0100 Subject: [PATCH 20/27] added slightly more informative comment --- src/bijectors/stacked.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 43ce7b74..80d1752d 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -47,7 +47,7 @@ function determine_output_ranges(bs, ranges) end end -# NOTE: I don't like this. +# NOTE: I don't like this but it seems necessary because `Stacked(...)` can occur in hot code paths. determine_output_ranges(bs::Tuple, ranges::Tuple) = determine_output_ranges_generated(bs, ranges) @generated function determine_output_ranges_generated(bs::Tuple, ranges::Tuple) N = length(bs.parameters) From bf68124412ab38e2dcfbdf90604d32ce31ba8632 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 12:31:47 +0100 Subject: [PATCH 21/27] format --- src/bijectors/stacked.jl | 10 ++++++---- test/bijectors/stacked.jl | 3 +-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 80d1752d..7a29ff48 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -29,7 +29,7 @@ end function Stacked(bs, ranges_in) ranges_out = determine_output_ranges(bs, ranges_in) - return Stacked{typeof(bs), typeof(ranges_in)}(bs, ranges_in, ranges_out) + return Stacked{typeof(bs),typeof(ranges_in)}(bs, ranges_in, ranges_out) end Stacked(bs::AbstractVector, ranges::Tuple) = Stacked(bs, [ranges...]) Stacked(bs::Tuple, ranges::AbstractVector) = Stacked([bs...], ranges) @@ -48,18 +48,20 @@ function determine_output_ranges(bs, ranges) end # NOTE: I don't like this but it seems necessary because `Stacked(...)` can occur in hot code paths. 
-determine_output_ranges(bs::Tuple, ranges::Tuple) = determine_output_ranges_generated(bs, ranges) +function determine_output_ranges(bs::Tuple, ranges::Tuple) + return determine_output_ranges_generated(bs, ranges) +end @generated function determine_output_ranges_generated(bs::Tuple, ranges::Tuple) N = length(bs.parameters) exprs = [] push!(exprs, :(offset = 0)) rsyms = [] - for i = 1:N + for i in 1:N rsym = Symbol("r_$i") lengthsym = Symbol("length_$i") push!(exprs, :($lengthsym = output_length(bs[$i], length(ranges[$i])))) - push!(exprs, :($rsym = offset .+ (1:$lengthsym))) + push!(exprs, :($rsym = offset .+ (1:($lengthsym)))) push!(exprs, :(offset += $lengthsym)) push!(rsyms, rsym) diff --git a/test/bijectors/stacked.jl b/test/bijectors/stacked.jl index 866c7e41..655b63eb 100644 --- a/test/bijectors/stacked.jl +++ b/test/bijectors/stacked.jl @@ -15,7 +15,7 @@ end Stacked((elementwise(exp), ProjectionBijector()), (1:1, 2:3)), Stacked([elementwise(exp), ProjectionBijector()], [1:1, 2:3]), Stacked([elementwise(exp), ProjectionBijector()], (1:1, 2:3)), - Stacked((elementwise(exp), ProjectionBijector()), [1:1, 2:3]) + Stacked((elementwise(exp), ProjectionBijector()), [1:1, 2:3]), ] @testset "$b" for b in bs binv = inverse(b) @@ -35,4 +35,3 @@ end @test binv(y) == [1.0, 2.0, 0.0] end end - From 45a9850e7bdb5b458a0d307a71f92d224ce2540e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 13:03:40 +0100 Subject: [PATCH 22/27] more fixes to that damned Stacked --- src/bijectors/stacked.jl | 62 ++++++++++++++++++++++++++++++---------- test/interface.jl | 21 +++++++------- 2 files changed, 58 insertions(+), 25 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 7a29ff48..7c955971 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -25,11 +25,15 @@ struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}} <: Transform bs::Bs ranges_in::Rs ranges_out::Rs + length_in::Int + length_out::Int end function Stacked(bs, 
ranges_in) ranges_out = determine_output_ranges(bs, ranges_in) - return Stacked{typeof(bs),typeof(ranges_in)}(bs, ranges_in, ranges_out) + return Stacked{typeof(bs),typeof(ranges_in)}( + bs, ranges_in, ranges_out, sum(length, ranges_in), sum(length, ranges_out) + ) end Stacked(bs::AbstractVector, ranges::Tuple) = Stacked(bs, [ranges...]) Stacked(bs::Tuple, ranges::AbstractVector) = Stacked([bs...], ranges) @@ -99,7 +103,11 @@ isclosedform(b::Stacked) = all(isclosedform, b.bs) isinvertible(b::Stacked) = all(isinvertible, b.bs) # For some reason `inverse.(sb.bs)` was unstable... This works though. -inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges_out, sb.ranges_in) +function inverse(sb::Stacked) + return Stacked( + map(inverse, sb.bs), sb.ranges_out, sb.ranges_in, sb.length_out, sb.length_in + ) +end # map is not type stable for many stacked bijectors as a large tuple # hence the generated function @generated function inverse(sb::Stacked{A}) where {A<:Tuple} @@ -107,31 +115,32 @@ inverse(sb::Stacked) = Stacked(map(inverse, sb.bs), sb.ranges_out, sb.ranges_in) for i in 1:length(A.parameters) push!(exprs, :(inverse(sb.bs[$i]))) end - return :(Stacked(($(exprs...),), sb.ranges_out, sb.ranges_in)) + return :(Stacked( + ($(exprs...),), sb.ranges_out, sb.ranges_in, sb.length_out, sb.length_in + )) end -function output_size(b::Stacked, sz::Tuple{Int}) - sz_out = sum(length, b.ranges_out) - return (sz_out,) -end +output_size(b::Stacked, sz::Tuple{Int}) = (b.length_out,) -@generated function _transform(x, rs::NTuple{N,UnitRange{Int}}, bs...) where {N} +@generated function _transform_stacked_recursive( + x, rs::NTuple{N,UnitRange{Int}}, bs... 
+) where {N} exprs = [] for i in 1:N push!(exprs, :(bs[$i](x[rs[$i]]))) end return :(vcat($(exprs...))) end -function _transform(x, rs::NTuple{1,UnitRange{Int}}, b) - @assert rs[1] == 1:length(x) +function _transform_stacked_recursive(x, rs::NTuple{1,UnitRange{Int}}, b) + rs[1] == 1:length(x) || error("range must be 1:length(x)") return b(x) end -function transform(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real}) - y = _transform(x, sb.ranges_in, sb.bs...) +function _transform_stacked(sb::Stacked{<:Tuple,<:Tuple}, x::AbstractVector{<:Real}) + y = _transform_stacked_recursive(x, sb.ranges_in, sb.bs...) return y end # The Stacked{<:AbstractArray} version is not TrackedArray friendly -function transform(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) +function _transform_stacked(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) N = length(sb.bs) N == 1 && return sb.bs[1](x[sb.ranges_in[1]]) @@ -141,6 +150,17 @@ function transform(sb::Stacked{<:AbstractArray}, x::AbstractVector{<:Real}) return y end +function transform(sb::Stacked, x::AbstractVector{<:Real}) + if sb.length_in != length(x) + error("input length mismatch ($(sb.length_in) != $(length(x)))") + end + y = _transform_stacked(sb, x) + if sb.length_out != length(y) + error("output length mismatch ($(sb.length_out) != $(length(y)))") + end + return y +end + function logabsdetjac(b::Stacked, x::AbstractVector{<:Real}) N = length(b.bs) init = sum(logabsdetjac(b.bs[1], x[b.ranges_in[1]])) @@ -177,7 +197,7 @@ end # logjac += sum(_logjac) # return (vcat(y_1, y_2), logjac) # end -@generated function with_logabsdet_jacobian( +@generated function _with_logabsdet_jacobian( b::Stacked{<:NTuple{N,Any},<:NTuple{N,Any}}, x::AbstractVector ) where {N} expr = Expr(:block) @@ -206,7 +226,7 @@ end return expr end -function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) +function _with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) N = length(sb.bs) yinit, linit = 
with_logabsdet_jacobian(sb.bs[1], x[sb.ranges_in[1]]) logjac = sum(linit) @@ -217,3 +237,15 @@ function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) end return (ys, logjac) end + +function with_logabsdet_jacobian(sb::Stacked, x::AbstractVector) + if sb.length_in != length(x) + error("input length mismatch ($(sb.length_in) != $(length(x)))") + end + y, logjac = _with_logabsdet_jacobian(sb, x) + if output_length(sb, length(x)) != length(y) + error("output length mismatch ($(output_length(sb, length(x))) != $(length(y)))") + end + + return (y, logjac) +end diff --git a/test/interface.jl b/test/interface.jl index 7d989abf..fdcac554 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -267,7 +267,7 @@ end @test sb(x) == [exp(x[1]), log(x[2]), x[3] + 5.0] @test res[1] == [exp(x[1]), log(x[2]), x[3] + 5.0] @test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:3]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:3]) @test res[2] == logabsdetjac(sb, x) # TODO: change when we have dimensionality in the type @@ -278,11 +278,11 @@ end @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) # Array-version sb = Stacked([elementwise(exp), SimplexBijector()], [1:1, 2:3]) @@ -292,11 +292,11 @@ end @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] 
@test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) # Mixed versions # Tuple, Array @@ -307,11 +307,11 @@ end @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) # Array, Tuple sb = Stacked((elementwise(exp), SimplexBijector()), [1:1, 2:3]) @@ -321,11 +321,11 @@ end @test sb(x) == [exp(x[1]), sb.bs[2](x[2:3])...] @test res[1] == [exp(x[1]), sb.bs[2](x[2:3])...] @test logabsdetjac(sb, x) == - sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges[i]])) for i in 1:2]) + sum([sum(logabsdetjac(sb.bs[i], x[sb.ranges_in[i]])) for i in 1:2]) @test res[2] == logabsdetjac(sb, x) x = ones(4) ./ 4.0 - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) @testset "Stacked: ADVI with MvNormal" begin # MvNormal test @@ -369,7 +369,7 @@ end # check that wrong ranges fails sb = Stacked(ibs) x = rand(d) - @test_throws AssertionError sb(x) + @test_throws ErrorException sb(x) # Stacked{<:Tuple} bs = bijector.(tuple(dists...)) @@ -406,6 +406,7 @@ end end end + @testset "Example: ADVI single" begin # Usage in ADVI d = Beta() From 1f0c374df1d2721e3ab4d116db7a373df83e1c6f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 13:06:22 +0100 Subject: [PATCH 23/27] Update test/interface.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/interface.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index 
fdcac554..c316ed09 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -406,7 +406,6 @@ end end end - @testset "Example: ADVI single" begin # Usage in ADVI d = Beta() From a917c2bb132d2be3d0555d1fb24334039f59f87b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 13:20:29 +0100 Subject: [PATCH 24/27] specialized constructors for Stacked further --- src/bijectors/stacked.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 7c955971..e1ebb6d8 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -29,14 +29,20 @@ struct Stacked{Bs,Rs<:Union{Tuple,AbstractArray}} <: Transform length_out::Int end -function Stacked(bs, ranges_in) +function Stacked(bs::AbstractArray, ranges_in::AbstractArray) ranges_out = determine_output_ranges(bs, ranges_in) return Stacked{typeof(bs),typeof(ranges_in)}( bs, ranges_in, ranges_out, sum(length, ranges_in), sum(length, ranges_out) ) end -Stacked(bs::AbstractVector, ranges::Tuple) = Stacked(bs, [ranges...]) -Stacked(bs::Tuple, ranges::AbstractVector) = Stacked([bs...], ranges) +function Stacked(bs::Tuple, ranges_in::Tuple) + ranges_out = determine_output_ranges(bs, ranges_in) + return Stacked{typeof(bs),typeof(ranges_in)}( + bs, ranges_in, ranges_out, sum(length, ranges_in), sum(length, ranges_out) + ) +end +Stacked(bs::AbstractArray, ranges::Tuple) = Stacked(bs, [ranges...]) +Stacked(bs::Tuple, ranges::AbstractArray) = Stacked([bs...], ranges) Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs))) Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)]) Stacked(bs...) 
= Stacked(bs, ntuple(i -> i:i, length(bs))) From cdd951a01f751ebf2fb2e83be848a28a13092bb2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 13:46:15 +0100 Subject: [PATCH 25/27] fixed bug in output_size for CorrVecBijector --- src/bijectors/corr.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index 179ed4e1..a4ed4740 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -235,7 +235,7 @@ end function output_size(::VecCorrBijector, sz::Tuple{Int,Int}) sz[1] == sz[2] || error("sizes should be equal; received $(sz)") n = sz[1] - return (n * (n - 1)) ÷ 2 + return ((n * (n - 1)) ÷ 2,) end function output_size(::Inverse{VecCorrBijector}, sz::Tuple{Int}) From 5dbd829c45a082b7c5467b3824722333fe36a745 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 17:11:18 +0100 Subject: [PATCH 26/27] Apply suggestions from code review Co-authored-by: David Widmann --- src/bijectors/stacked.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index e1ebb6d8..73abf51c 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -41,8 +41,8 @@ function Stacked(bs::Tuple, ranges_in::Tuple) bs, ranges_in, ranges_out, sum(length, ranges_in), sum(length, ranges_out) ) end -Stacked(bs::AbstractArray, ranges::Tuple) = Stacked(bs, [ranges...]) -Stacked(bs::Tuple, ranges::AbstractArray) = Stacked([bs...], ranges) +Stacked(bs::AbstractArray, ranges::Tuple) = Stacked(bs, collect(ranges)) +Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) Stacked(bs::Tuple) = Stacked(bs, ntuple(i -> i:i, length(bs))) Stacked(bs::AbstractArray) = Stacked(bs, [i:i for i in 1:length(bs)]) Stacked(bs...) 
= Stacked(bs, ntuple(i -> i:i, length(bs))) From 04f69904770f5eabd91e4e1e47fe5eec7dd363b5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 18 Jun 2023 22:45:08 +0100 Subject: [PATCH 27/27] Apply suggestions from code review Co-authored-by: David Widmann --- src/interface.jl | 3 +-- src/transformed_distribution.jl | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index f8c94a31..099df1bb 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -44,9 +44,8 @@ output_size(f, sz) = sz """ output_length(f, len::Int) - output_length(f, sz::Tuple) -Returns the output length of `f` given the input length `len` or size `sz`. +Returns the output length of `f` given the input length `len`. """ output_length(f, len::Int) = only(output_size(f, (len,))) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index bc9f2e90..04c3a559 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -100,7 +100,7 @@ end ############################## # size -Base.length(td::Transformed) = only(output_size(td.transform, size(td.dist))) +Base.length(td::Transformed) = prod(output_size(td.transform, size(td.dist))) Base.size(td::Transformed) = output_size(td.transform, size(td.dist)) function logpdf(td::UnivariateTransformed, y::Real)