Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.35.3"
version = "0.35.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
22 changes: 16 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,16 @@ function (f::UnwrapSingletonTransform)(x)
return only(x)
end

Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) = (f(x), 0)
function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x)
return f(x), zero(eltype(x))
end

function Bijectors.with_logabsdet_jacobian(
inv_f::Bijectors.Inverse{<:UnwrapSingletonTransform}, x
)
f = inv_f.orig
return (reshape([x], f.input_size), 0)
result = reshape([x], f.input_size)
return result, zero(eltype(x))
end

"""
Expand Down Expand Up @@ -306,18 +310,24 @@ function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x)
return inverse(x)
end

Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), 0)
Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), zero(eltype(x)))

function Bijectors.with_logabsdet_jacobian(inv_f::Bijectors.Inverse{<:ReshapeTransform}, x)
return (inv_f(x), 0)
return inv_f(x), zero(eltype(x))
end

struct ToChol <: Bijectors.Bijector
uplo::Char
end

Bijectors.with_logabsdet_jacobian(f::ToChol, x) = (Cholesky(Matrix(x), f.uplo, 0), 0)
Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) = (y.UL, 0)
function Bijectors.with_logabsdet_jacobian(f::ToChol, x)
return Cholesky(Matrix(x), f.uplo, 0), zero(eltype(x))
end

function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky)
return y.UL, zero(eltype(y))
end

function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y)
return error(
"Inverse{ToChol} is only defined for Cholesky factorizations. " *
Expand Down
Loading