Skip to content

Commit eb7ebd6

Browse files
author
Aidan Gleich
committed
Fix rand function for DegenerateMvNormal distributions
1 parent 1776c9c commit eb7ebd6

File tree

1 file changed

+21
-1
lines changed

1 file changed

+21
-1
lines changed

src/distributions_ext.jl

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,27 @@ Distributions.rand(d::DegenerateMvNormal; cc::T = 1.0) where T<:AbstractFloat
178178
Generate a draw from `d` with variance optionally scaled by `cc^2`.
179179
"""
180180
function Distributions.rand(d::DegenerateMvNormal; cc::T = 1.0) where T<:AbstractFloat
181-
return d.μ + cc*d.σ*randn(length(d))
181+
# abusing notation slightly, if Y is a degen MV normal r.v. with covariance matrix Σ,
182+
# and Σ = U Λ^2 Vt according to the svd, then given an standard MV normal r.v X with
183+
# the same dimension as Y, Y = μ + UΛX.
184+
185+
# we need to ensure symmetry when computing SVD
186+
U, λ_vals, Vt = svd((d.σ + d.σ')./2)
187+
188+
# set near-zero values to zero
189+
λ_vals[λ_vals .< 10^(-6)] .= 0
190+
191+
# leave x as 0 where λ_vals equals 0 (b/c r.v. is fixed where λ_vals = 0)
192+
λ_vals = abs.(λ_vals)
193+
x = zeros(length(λ_vals))
194+
for i in 1:length(λ_vals)
195+
if λ_vals[i] == 0
196+
x[i] = 0
197+
else
198+
x[i] = randn()
199+
end
200+
end
201+
return d.μ + cc*U*diagm(sqrt.(λ_vals))*x
182202
end
183203

184204
"""

0 commit comments

Comments
 (0)