Skip to content

Commit 52b0367

Browse files
committed
Fix zero-type of logjac for ReshapeTransform
1 parent 9df42bf commit 52b0367

File tree

1 file changed

+16
-6
lines changed

1 file changed

+16
-6
lines changed

src/utils.jl

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,16 @@ function (f::UnwrapSingletonTransform)(x)
252252
return only(x)
253253
end
254254

255-
Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) = (f(x), 0)
255+
function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x)
256+
return f(x), zero(eltype(x))
257+
end
258+
256259
function Bijectors.with_logabsdet_jacobian(
257260
inv_f::Bijectors.Inverse{<:UnwrapSingletonTransform}, x
258261
)
259262
f = inv_f.orig
260-
return (reshape([x], f.input_size), 0)
263+
result = reshape([x], f.input_size)
264+
return result, zero(eltype(x))
261265
end
262266

263267
"""
@@ -306,18 +310,24 @@ function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x)
306310
return inverse(x)
307311
end
308312

309-
Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), 0)
313+
Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), zero(eltype(x)))
310314

311315
function Bijectors.with_logabsdet_jacobian(inv_f::Bijectors.Inverse{<:ReshapeTransform}, x)
312-
return (inv_f(x), 0)
316+
return inv_f(x), zero(eltype(x))
313317
end
314318

315319
struct ToChol <: Bijectors.Bijector
316320
uplo::Char
317321
end
318322

319-
Bijectors.with_logabsdet_jacobian(f::ToChol, x) = (Cholesky(Matrix(x), f.uplo, 0), 0)
320-
Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) = (y.UL, 0)
323+
function Bijectors.with_logabsdet_jacobian(f::ToChol, x)
324+
return Cholesky(Matrix(x), f.uplo, 0), zero(eltype(x))
325+
end
326+
327+
function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky)
328+
return y.UL, zero(eltype(y))
329+
end
330+
321331
function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y)
322332
return error(
323333
"Inverse{ToChol} is only defined for Cholesky factorizations. " *

0 commit comments

Comments
 (0)