Skip to content

Commit e5cd924

Browse files
committed
minor fixes
1 parent 6de8814 commit e5cd924

File tree

2 files changed

+13
-2
lines changed

2 files changed

+13
-2
lines changed

src/multivariate.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,16 @@ for T in (:TrackedVector, :TrackedMatrix)
175175
function Distributions.logpdf(d::MvNormal{<:Any, <:PDMats.PDMat}, x::$T)
176176
logpdf(TuringDenseMvNormal(d.μ, d.Σ.chol), x)
177177
end
178+
179+
function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.ScalMat}, x::$T)
180+
logpdf(TuringMvLogNormal(TuringScalMvNormal(d.normal.μ, d.normal.Σ.value)), x)
181+
end
182+
function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.PDiagMat}, x::$T)
183+
logpdf(TuringMvLogNormal(TuringDiagMvNormal(d.normal.μ, d.normal.Σ.diag)), x)
184+
end
185+
function Distributions.logpdf(d::MvLogNormal{<:Any, <:PDMats.PDMat}, x::$T)
186+
logpdf(TuringMvLogNormal(TuringDenseMvNormal(d.normal.μ, d.normal.Σ.chol)), x)
187+
end
178188
end
179189
end
180190

src/univariate.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,15 @@ function TuringUniform(a::Real, b::Real)
1313
return TuringUniform{T}(T(a), T(b))
1414
end
1515
Distributions.logpdf(d::TuringUniform, x::Real) = uniformlogpdf(d.a, d.b, x)
16+
Distributions.logpdf(d::TuringUniform, x::AbstractArray) = uniformlogpdf.(d.a, d.b, x)
1617
Base.minimum(d::TuringUniform) = d.a
1718
Base.maximum(d::TuringUniform) = d.b
1819

1920
Distributions.Uniform(a::TrackedReal, b::Real) = TuringUniform{TrackedReal}(a, b)
2021
Distributions.Uniform(a::Real, b::TrackedReal) = TuringUniform{TrackedReal}(a, b)
2122
Distributions.Uniform(a::TrackedReal, b::TrackedReal) = TuringUniform{TrackedReal}(a, b)
2223
Distributions.logpdf(d::Uniform, x::TrackedReal) = uniformlogpdf(d.a, d.b, x)
23-
24+
Distributions.logpdf(d::Uniform, x::TrackedArray) = uniformlogpdf.(d.a, d.b, x)
2425
function uniformlogpdf(a, b, x)
2526
c = -log(b - a)
2627
if a <= x <= b
@@ -53,7 +54,7 @@ end
5354
return l, Δ -> (da * Δ, -da * Δ, zero(T) * Δ)
5455
else
5556
n = T(NaN)
56-
return l, Δ -> (n, n, n)
57+
return n, Δ -> (n, n, n)
5758
end
5859
end
5960
@adjoint function Distributions.Uniform(args...)

0 commit comments

Comments
 (0)