Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6b46e34
added output_length and output_size to compute output, well, leengths
torfjelde Jun 17, 2023
41d0b06
added tests for size of transformed dist using VcCorrBijector
torfjelde Jun 17, 2023
a76f18a
use already constructed transfrormation
torfjelde Jun 17, 2023
6afc77e
TransformedDistribution should now also have correct variate form
torfjelde Jun 17, 2023
a4c5683
added proper variateform handling for VecCholeskyBijector too
torfjelde Jun 17, 2023
ea724ee
Apply suggestions from code review
torfjelde Jun 17, 2023
387ef5a
added output_size impl for Reshape too
torfjelde Jun 17, 2023
acb5e8f
bump minor version
torfjelde Jun 18, 2023
3391735
Apply suggestions from code review
torfjelde Jun 18, 2023
b524ebb
Update src/interface.jl
torfjelde Jun 18, 2023
d6dc906
Update src/bijectors/corr.jl
torfjelde Jun 18, 2023
280708b
reverted removal of length as we'll need it now
torfjelde Jun 18, 2023
2069d69
updated Stacked to be compat with changing sizes
torfjelde Jun 18, 2023
f533a79
forgot to commit deetion
torfjelde Jun 18, 2023
56b8834
Apply suggestions from code review
torfjelde Jun 18, 2023
098a9c0
add testing of sizes to `test_bijector`
torfjelde Jun 18, 2023
4e14bb2
some more tests for stacked
torfjelde Jun 18, 2023
def7c6f
Update test/bijectors/stacked.jl
torfjelde Jun 18, 2023
fe36875
added awful generated function to determine output ranges for Stacked
torfjelde Jun 18, 2023
bbfaf19
added slightly more informative comment
torfjelde Jun 18, 2023
bf68124
format
torfjelde Jun 18, 2023
45a9850
more fixes to that damned Stacked
torfjelde Jun 18, 2023
1f0c374
Update test/interface.jl
torfjelde Jun 18, 2023
a917c2b
specialized constructors for Stacked further
torfjelde Jun 18, 2023
cdd951a
fixed bug in output_size for CorrVecBijector
torfjelde Jun 18, 2023
5dbd829
Apply suggestions from code review
torfjelde Jun 18, 2023
04f6990
Apply suggestions from code review
torfjelde Jun 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ function logabsdetjac(::Inverse{VecCorrBijector}, y::AbstractVector{<:Real})
return _logabsdetjac_inv_corr(y)
end

function output_size(::VecCorrBijector, sz::NTuple{2})
@assert sz[1] == sz[2]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe make this a proper, more descriptive error?

n = sz[1]
return (n * (n - 1)) ÷ 2
end

function output_size(::Inverse{VecCorrBijector}, sz::NTuple{1})
n = _triu1_dim_from_length(first(sz))
return (n, n)
end

"""
VecCholeskyBijector <: Bijector

Expand Down
20 changes: 20 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ function logabsdetjac(f::Columnwise, x::AbstractMatrix)
end
with_logabsdet_jacobian(f::Columnwise, x::AbstractMatrix) = (f(x), logabsdetjac(f, x))

"""
output_size(f, sz)

Returns the output size of `f` given the input size `sz`.
"""
output_size(f, sz) = sz

"""
output_length(f, len::Int)
output_length(f, sz::Tuple)

Returns the output length of `f` given the input length `len` or size `sz`.
"""
output_length(f, len::Int) = len
function output_length(f, len::Tuple)
sz = output_size(f, len)
@assert length(sz) == 1
return first(sz)
end

######################
# Bijector interface #
######################
Expand Down
26 changes: 11 additions & 15 deletions src/transformed_distribution.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
function variateform(d::Distribution{<:ArrayLikeVariate}, b)
sz_in = size(d)
sz_out = output_size(b, sz_in)
return ArrayLikeVariate{length(sz_out)}
end

# Transformed distributions
struct TransformedDistribution{D,B,V} <:
Distribution{V,Continuous} where {D<:Distribution{V,Continuous},B}
struct TransformedDistribution{D,B,V} <: Distribution{V,Continuous} where {D<:ContinuousDistribution,B}
dist::D
transform::B

function TransformedDistribution(d::UnivariateDistribution, b)
return new{typeof(d),typeof(b),Univariate}(d, b)
end
function TransformedDistribution(d::MultivariateDistribution, b)
return new{typeof(d),typeof(b),Multivariate}(d, b)
end
function TransformedDistribution(d::MatrixDistribution, b)
return new{typeof(d),typeof(b),Matrixvariate}(d, b)
end
function TransformedDistribution(d::Distribution{CholeskyVariate}, b)
return new{typeof(d),typeof(b),CholeskyVariate}(d, b)
function TransformedDistribution(d::ContinuousDistribution, b)
return new{typeof(d),typeof(b),variateform(d,b)}(d, b)
end
end

Expand Down Expand Up @@ -101,8 +97,8 @@ end
##############################

# size
Base.length(td::Transformed) = length(td.dist)
Base.size(td::Transformed) = size(td.dist)
Base.length(td::Transformed) = output_length(td.transform, size(td.dist))
Base.size(td::Transformed) = output_size(td.transform, size(td.dist))

function logpdf(td::UnivariateTransformed, y::Real)
x, logjac = with_logabsdet_jacobian(inverse(td.transform), y)
Expand Down
9 changes: 9 additions & 0 deletions test/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ using Bijectors: VecCorrBijector, VecCholeskyBijector, CorrBijector
test_bijector(bvec, x; test_not_identity=d != 1, changes_of_variables_test=false)

test_ad(x -> sum(bvec(bvecinv(x))), yvec)

# Check that output sizes are computed correctly.
dist = transformed(dist)
@test length(dist) == length(yvec)
@test dist isa MultivariateDistribution

dist_unconstrained = transformed(MvNormal(zeros(length(dist)), I), inverse(bvec))
@test size(dist_unconstrained) == size(x)
@test dist_unconstrained isa MatrixDistribution
end
end

Expand Down