Skip to content

Commit 5b58c47

Browse files
committed
Bijectors compat bugfixes
1 parent c26a228 commit 5b58c47

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

src/matrixvariate.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,17 @@ end
3535
Distributions.insupport(::Type{TuringWishart}, X::Matrix) = isposdef(X)
3636
Distributions.insupport(d::TuringWishart, X::Matrix) = size(X) == size(d) && isposdef(X)
3737

38-
dim(d::TuringWishart) = size(d.chol, 1)
39-
Base.size(d::TuringWishart) = (p = dim(d); (p, p))
38+
Distributions.dim(d::TuringWishart) = size(d.chol, 1)
39+
Base.size(d::TuringWishart) = (p = Distributions.dim(d); (p, p))
4040
Base.size(d::TuringWishart, i) = size(d)[i]
41-
LinearAlgebra.rank(d::TuringWishart) = dim(d)
41+
LinearAlgebra.rank(d::TuringWishart) = Distributions.dim(d)
4242

4343
#### Statistics
4444

4545
Distributions.mean(d::TuringWishart) = d.df * Matrix(d.chol)
4646

4747
function Distributions.mode(d::TuringWishart)
48-
r = d.df - dim(d) - 1.0
48+
r = d.df - Distributions.dim(d) - 1.0
4949
if r > 0.0
5050
return Matrix(d.chol) * r
5151
else
@@ -54,7 +54,7 @@ function Distributions.mode(d::TuringWishart)
5454
end
5555

5656
function Distributions.meanlogdet(d::TuringWishart)
57-
p = dim(d)
57+
p = Distributions.dim(d)
5858
df = d.df
5959
v = logdet(d.chol) + p * logtwo
6060
for i = 1:p
@@ -64,7 +64,7 @@ function Distributions.meanlogdet(d::TuringWishart)
6464
end
6565

6666
function Distributions.entropy(d::TuringWishart)
67-
p = dim(d)
67+
p = Distributions.dim(d)
6868
df = d.df
6969
d.c0 - 0.5 * (df - p - 1) * meanlogdet(d) + 0.5 * df * p
7070
end
@@ -84,14 +84,14 @@ end
8484

8585
function Distributions.logpdf(d::TuringWishart, X::AbstractMatrix{<:Real})
8686
df = d.df
87-
p = dim(d)
87+
p = Distributions.dim(d)
8888
Xcf = cholesky(X)
8989
return 0.5 * ((df - (p + 1)) * logdet(Xcf) - tr(d.chol \ X)) - d.c0
9090
end
9191

9292
#### Sampling
9393
function Distributions._rand!(rng::AbstractRNG, d::TuringWishart, A::AbstractMatrix)
94-
_wishart_genA!(rng, dim(d), d.df, A)
94+
_wishart_genA!(rng, Distributions.dim(d), d.df, A)
9595
unwhiten!(d.chol, A)
9696
A .= A * A'
9797
end
@@ -147,16 +147,16 @@ end
147147
Distributions.insupport(::Type{TuringInverseWishart}, X::Matrix) = isposdef(X)
148148
Distributions.insupport(d::TuringInverseWishart, X::Matrix) = size(X) == size(d) && isposdef(X)
149149

150-
dim(d::TuringInverseWishart) = size(d.S, 1)
151-
Base.size(d::TuringInverseWishart) = (p = dim(d); (p, p))
150+
Distributions.dim(d::TuringInverseWishart) = size(d.S, 1)
151+
Base.size(d::TuringInverseWishart) = (p = Distributions.dim(d); (p, p))
152152
Base.size(d::TuringInverseWishart, i) = size(d)[i]
153-
LinearAlgebra.rank(d::TuringInverseWishart) = dim(d)
153+
LinearAlgebra.rank(d::TuringInverseWishart) = Distributions.dim(d)
154154

155155
#### Statistics
156156

157157
function Distributions.mean(d::TuringInverseWishart)
158158
df = d.df
159-
p = dim(d)
159+
p = Distributions.dim(d)
160160
r = df - (p + 1)
161161
if r > 0.0
162162
return d.S * (1.0 / r)
@@ -165,25 +165,25 @@ function Distributions.mean(d::TuringInverseWishart)
165165
end
166166
end
167167

168-
Distributions.mode(d::TuringInverseWishart) = d.S * inv(d.df + dim(d) + 1.0)
168+
Distributions.mode(d::TuringInverseWishart) = d.S * inv(d.df + Distributions.dim(d) + 1.0)
169169

170170
# https://en.wikipedia.org/wiki/Inverse-Wishart_distribution#Moments
171171
function Distributions.cov(d::TuringInverseWishart, i::Integer, j::Integer, k::Integer, l::Integer)
172-
p, ν, Ψ = (dim(d), d.df, d.S)
172+
p, ν, Ψ = (Distributions.dim(d), d.df, d.S)
173173
ν > p + 3 || throw(ArgumentError("cov only defined for df > dim + 3"))
174174
inv((ν - p)*- p - 3)*- p - 1)^2)*(2Ψ[i,j]*Ψ[k,l] +-p-1)*(Ψ[i,k]*Ψ[j,l] + Ψ[i,l]*Ψ[k,j]))
175175
end
176176

177177
function Distributions.var(d::TuringInverseWishart, i::Integer, j::Integer)
178-
p, ν, Ψ = (dim(d), d.df, d.S)
178+
p, ν, Ψ = (Distributions.dim(d), d.df, d.S)
179179
ν > p + 3 || throw(ArgumentError("var only defined for df > dim + 3"))
180180
inv((ν - p)*- p - 3)*- p - 1)^2)*- p + 1)*Ψ[i,j]^2 +- p - 1)*Ψ[i,i]*Ψ[j,j]
181181
end
182182

183183
#### Evaluation
184184

185185
function Distributions.logpdf(d::TuringInverseWishart, X::AbstractMatrix{<:Real})
186-
p = dim(d)
186+
p = Distributions.dim(d)
187187
df = d.df
188188
Xcf = cholesky(X)
189189
# we use the fact: tr(Ψ * inv(X)) = tr(inv(X) * Ψ) = tr(X \ Ψ)
@@ -194,8 +194,10 @@ end
194194

195195
#### Sampling
196196

197-
Distributions._rand!(rng::AbstractRNG, d::TuringInverseWishart, A::AbstractMatrix) =
198-
(A .= inv(cholesky!(_rand!(rng, TuringWishart(d.df, inv(cholesky(d.S))), A))))
197+
function Distributions._rand!(rng::AbstractRNG, d::TuringInverseWishart, A::AbstractMatrix)
198+
X = Distributions._rand!(rng, TuringWishart(d.df, inv(cholesky(d.S))), A)
199+
A .= inv(cholesky!(X))
200+
end
199201

200202
## Adjoints
201203

0 commit comments

Comments
 (0)