From 52b036726e80eed362e94d9ba3dba9f2ae06d200 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Mar 2025 17:59:17 +0000 Subject: [PATCH 1/5] Fix zero-type of logjac for ReshapeTransform --- src/utils.jl | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index d64f6dc66..ca0f4436e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -252,12 +252,16 @@ function (f::UnwrapSingletonTransform)(x) return only(x) end -Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) = (f(x), 0) +function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) + return f(x), zero(eltype(x)) +end + function Bijectors.with_logabsdet_jacobian( inv_f::Bijectors.Inverse{<:UnwrapSingletonTransform}, x ) f = inv_f.orig - return (reshape([x], f.input_size), 0) + result = reshape([x], f.input_size) + return result, zero(eltype(x)) end """ @@ -306,18 +310,24 @@ function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x) return inverse(x) end -Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), 0) +Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), zero(eltype(x))) function Bijectors.with_logabsdet_jacobian(inv_f::Bijectors.Inverse{<:ReshapeTransform}, x) - return (inv_f(x), 0) + return inv_f(x), zero(eltype(x)) end struct ToChol <: Bijectors.Bijector uplo::Char end -Bijectors.with_logabsdet_jacobian(f::ToChol, x) = (Cholesky(Matrix(x), f.uplo, 0), 0) -Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) = (y.UL, 0) +function Bijectors.with_logabsdet_jacobian(f::ToChol, x) + return Cholesky(Matrix(x), f.uplo, 0), zero(eltype(x)) +end + +function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) + return y.UL, zero(eltype(y)) +end + function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y) return error( "Inverse{ToChol} is only defined for Cholesky factorizations. " * From 4bc00f4e299f1c53fe55054cba0f96571d0c492e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Mar 2025 18:04:58 +0000 Subject: [PATCH 2/5] Bump patch version to 0.35.4 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index e546dae9d..86bcacd7f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.35.3" +version = "0.35.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From fabdc444538f6f2a0c05f277e4bd5f039af8f4db Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Mar 2025 18:41:36 +0000 Subject: [PATCH 3/5] Make logpdf of NoDist be of the eltype of the argument --- src/distribution_wrappers.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index d7097f5b4..5f4c68f82 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -54,30 +54,30 @@ function Distributions.rand!( ) where {N} return Distributions.rand!(rng, d.dist, x) end -Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0 -Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 -function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) - return zeros(Int, size(x, 2)) +Distributions.logpdf(::NoDist{<:Univariate}, x::Real) = zero(eltype(x)) +Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractVector{<:Real}) = zero(eltype(x)) +function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) + return zeros(eltype(x), size(x, 2)) end -Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 +Distributions.logpdf(::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}) = zero(eltype(x)) Distributions.minimum(d::NoDist) = minimum(d.dist) Distributions.maximum(d::NoDist) = maximum(d.dist) -Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0 +Bijectors.logpdf_with_trans(::NoDist{<:Univariate}, x::Real, ::Bool) = zero(eltype(x)) function Bijectors.logpdf_with_trans( - d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool + ::NoDist{<:Multivariate}, x::AbstractVector{<:Real}, ::Bool ) - return 0 + return zero(eltype(x)) end function Bijectors.logpdf_with_trans( - d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool + ::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool ) - return zeros(Int, size(x, 2)) + return zeros(eltype(x), size(x, 2)) end function Bijectors.logpdf_with_trans( - d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool + ::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}, ::Bool ) - return 0 + return zero(eltype(x)) end Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist) From 93c1cbb478fae0a2bffebe0c32c201e81bdd18d9 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 18 Mar 2025 18:56:14 +0000 Subject: [PATCH 4/5] Use float_type_with_fallback for logjacs and logpdfs --- src/distribution_wrappers.jl | 24 ++++++++++++++++-------- src/utils.jl | 14 ++++++++------ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 5f4c68f82..9dd04e0b4 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -54,30 +54,38 @@ function Distributions.rand!( ) where {N} return Distributions.rand!(rng, d.dist, x) end -Distributions.logpdf(::NoDist{<:Univariate}, x::Real) = zero(eltype(x)) -Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractVector{<:Real}) = zero(eltype(x)) +function Distributions.logpdf(::NoDist{<:Univariate}, x::Real) + return zero(float_type_with_fallback(eltype(x))) +end +function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractVector{<:Real}) + return zero(float_type_with_fallback(eltype(x))) +end function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) - return zeros(eltype(x), size(x, 2)) + return zeros(float_type_with_fallback(eltype(x)), size(x, 2)) +end +function Distributions.logpdf(::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}) + return zero(float_type_with_fallback(eltype(x))) end -Distributions.logpdf(::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}) = zero(eltype(x)) Distributions.minimum(d::NoDist) = minimum(d.dist) Distributions.maximum(d::NoDist) = maximum(d.dist) -Bijectors.logpdf_with_trans(::NoDist{<:Univariate}, x::Real, ::Bool) = zero(eltype(x)) +function Bijectors.logpdf_with_trans(::NoDist{<:Univariate}, x::Real, ::Bool) + return zero(float_type_with_fallback(eltype(x))) +end function Bijectors.logpdf_with_trans( ::NoDist{<:Multivariate}, x::AbstractVector{<:Real}, ::Bool ) - return zero(eltype(x)) + return zero(float_type_with_fallback(eltype(x))) end function Bijectors.logpdf_with_trans( ::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool ) - return zeros(eltype(x), size(x, 2)) + return zeros(float_type_with_fallback(eltype(x)), size(x, 2)) end function Bijectors.logpdf_with_trans( ::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}, ::Bool ) - return zero(eltype(x)) + return zero(float_type_with_fallback(eltype(x))) end Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist) diff --git a/src/utils.jl b/src/utils.jl index ca0f4436e..f54baaa87 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -253,7 +253,7 @@ function (f::UnwrapSingletonTransform)(x) end function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) - return f(x), zero(eltype(x)) + return f(x), zero(float_type_with_fallback(eltype(x))) end function Bijectors.with_logabsdet_jacobian( @@ -261,7 +261,7 @@ function Bijectors.with_logabsdet_jacobian( ) f = inv_f.orig result = reshape([x], f.input_size) - return result, zero(eltype(x)) + return result, zero(float_type_with_fallback(eltype(x))) end """ @@ -310,10 +310,12 @@ function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x) return inverse(x) end -Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) = (f(x), zero(eltype(x))) +function Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) + return f(x), zero(float_type_with_fallback(eltype(x))) +end function Bijectors.with_logabsdet_jacobian(inv_f::Bijectors.Inverse{<:ReshapeTransform}, x) - return inv_f(x), zero(eltype(x)) + return inv_f(x), zero(float_type_with_fallback(eltype(x))) end struct ToChol <: Bijectors.Bijector @@ -321,11 +323,11 @@ struct ToChol <: Bijectors.Bijector end function Bijectors.with_logabsdet_jacobian(f::ToChol, x) - return Cholesky(Matrix(x), f.uplo, 0), zero(eltype(x)) + return Cholesky(Matrix(x), f.uplo, 0), zero(float_type_with_fallback(eltype(x))) end function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) - return y.UL, zero(eltype(y)) + return y.UL, zero(float_type_with_fallback(eltype(y))) end function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y) From 9aa8d7bb5795530e6dd45a62b663dbb163d56959 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 19 Mar 2025 11:49:54 +0000 Subject: [PATCH 5/5] Make LogProbType be float(Real) --- src/distribution_wrappers.jl | 16 ++++++++-------- src/utils.jl | 22 ++++++++++++++++------ src/varinfo.jl | 2 +- 3 files changed, 25 insertions(+), 15 deletions(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 9dd04e0b4..0c7fc446a 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -55,37 +55,37 @@ function Distributions.rand!( return Distributions.rand!(rng, d.dist, x) end function Distributions.logpdf(::NoDist{<:Univariate}, x::Real) - return zero(float_type_with_fallback(eltype(x))) + return zero(LogProbType) end function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractVector{<:Real}) - return zero(float_type_with_fallback(eltype(x))) + return zero(LogProbType) end function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) - return zeros(float_type_with_fallback(eltype(x)), size(x, 2)) + return zeros(LogProbType, size(x, 2)) end function Distributions.logpdf(::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}) - return zero(float_type_with_fallback(eltype(x))) + return zero(LogProbType) end Distributions.minimum(d::NoDist) = minimum(d.dist) Distributions.maximum(d::NoDist) = maximum(d.dist) function Bijectors.logpdf_with_trans(::NoDist{<:Univariate}, x::Real, ::Bool) - return zero(float_type_with_fallback(eltype(x))) + return zero(LogProbType) end function Bijectors.logpdf_with_trans( ::NoDist{<:Multivariate}, x::AbstractVector{<:Real}, ::Bool ) - return zero(float_type_with_fallback(eltype(x))) + return zero(LogProbType) end function Bijectors.logpdf_with_trans( ::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool ) - return zeros(float_type_with_fallback(eltype(x)), size(x, 2)) + return zeros(LogProbType, size(x, 2)) end function Bijectors.logpdf_with_trans( ::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}, ::Bool ) - return zero(float_type_with_fallback(eltype(x))) + return zero(LogProbType) end Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist) diff --git a/src/utils.jl b/src/utils.jl index f54baaa87..50f9baf61 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,6 +5,16 @@ const NO_DEFAULT = NoDefault() # A short-hand for a type commonly used in type signatures for VarInfo methods. VarNameTuple = NTuple{N,VarName} where {N} +# TODO(mhauru) This is currently used in the transformation functions of NoDist, +# ReshapeTransform, and UnwrapSingletonTransform, and in VarInfo. We should also use it in +# SimpleVarInfo and maybe other places. +""" +The type for all log probability variables. + +This is Float64 on 64-bit systems and Float32 on 32-bit systems. +""" +const LogProbType = float(Real) + """ @addlogprob!(ex) @@ -253,7 +263,7 @@ function (f::UnwrapSingletonTransform)(x) end function Bijectors.with_logabsdet_jacobian(f::UnwrapSingletonTransform, x) - return f(x), zero(float_type_with_fallback(eltype(x))) + return f(x), zero(LogProbType) end function Bijectors.with_logabsdet_jacobian( @@ -261,7 +271,7 @@ function Bijectors.with_logabsdet_jacobian( ) f = inv_f.orig result = reshape([x], f.input_size) - return result, zero(float_type_with_fallback(eltype(x))) + return result, zero(LogProbType) end """ @@ -311,11 +321,11 @@ function (inv_f::Bijectors.Inverse{<:ReshapeTransform})(x) end function Bijectors.with_logabsdet_jacobian(f::ReshapeTransform, x) - return f(x), zero(float_type_with_fallback(eltype(x))) + return f(x), zero(LogProbType) end function Bijectors.with_logabsdet_jacobian(inv_f::Bijectors.Inverse{<:ReshapeTransform}, x) - return inv_f(x), zero(float_type_with_fallback(eltype(x))) + return inv_f(x), zero(LogProbType) end struct ToChol <: Bijectors.Bijector @@ -323,11 +333,11 @@ struct ToChol <: Bijectors.Bijector end function Bijectors.with_logabsdet_jacobian(f::ToChol, x) - return Cholesky(Matrix(x), f.uplo, 0), zero(float_type_with_fallback(eltype(x))) + return Cholesky(Matrix(x), f.uplo, 0), zero(LogProbType) end function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y::Cholesky) - return y.UL, zero(float_type_with_fallback(eltype(y))) + return y.UL, zero(LogProbType) end function Bijectors.with_logabsdet_jacobian(::Bijectors.Inverse{<:ToChol}, y) diff --git a/src/varinfo.jl b/src/varinfo.jl index d27a82437..2fd5894aa 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -821,7 +821,7 @@ end # VarInfo -VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0)) +VarInfo(meta=Metadata()) = VarInfo(meta, Ref{LogProbType}(0.0), Ref(0)) function TypedVarInfo(vi::VectorVarInfo) new_metas = group_by_symbol(vi.metadata)