Skip to content

Commit fabdc44

Browse files
committed
Make logpdf of NoDist be of the eltype of the argument
1 parent 4bc00f4 commit fabdc44

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

src/distribution_wrappers.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -54,30 +54,30 @@ function Distributions.rand!(
5454
) where {N}
5555
return Distributions.rand!(rng, d.dist, x)
5656
end
57-
Distributions.logpdf(d::NoDist{<:Univariate}, ::Real) = 0
58-
Distributions.logpdf(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0
59-
function Distributions.logpdf(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
60-
return zeros(Int, size(x, 2))
57+
Distributions.logpdf(::NoDist{<:Univariate}, x::Real) = zero(eltype(x))
58+
Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractVector{<:Real}) = zero(eltype(x))
59+
function Distributions.logpdf(::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real})
60+
return zeros(eltype(x), size(x, 2))
6161
end
62-
Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0
62+
Distributions.logpdf(::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}) = zero(eltype(x))
6363
Distributions.minimum(d::NoDist) = minimum(d.dist)
6464
Distributions.maximum(d::NoDist) = maximum(d.dist)
6565

66-
Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
66+
Bijectors.logpdf_with_trans(::NoDist{<:Univariate}, x::Real, ::Bool) = zero(eltype(x))
6767
function Bijectors.logpdf_with_trans(
68-
d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool
68+
::NoDist{<:Multivariate}, x::AbstractVector{<:Real}, ::Bool
6969
)
70-
return 0
70+
return zero(eltype(x))
7171
end
7272
function Bijectors.logpdf_with_trans(
73-
d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool
73+
::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool
7474
)
75-
return zeros(Int, size(x, 2))
75+
return zeros(eltype(x), size(x, 2))
7676
end
7777
function Bijectors.logpdf_with_trans(
78-
d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool
78+
::NoDist{<:Matrixvariate}, x::AbstractMatrix{<:Real}, ::Bool
7979
)
80-
return 0
80+
return zero(eltype(x))
8181
end
8282

8383
Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist)

0 commit comments

Comments
 (0)