Skip to content

Commit d04d20c

Browse files
projekterDesef, Benjamindevmotion
authored
Make MatrixNormal sampling non-allocating (#2012)
* Make MatrixNormal sampling non-allocating * Add test for non-allocating MatrixNormal sampling * Update non-allocating test Co-authored-by: David Müller-Widmann <devmotion@users.noreply.github.com> * Fix Co-authored-by: David Müller-Widmann <devmotion@users.noreply.github.com> * Adjust allocation test to Julia <1.10.5 failure Co-authored-by: David Müller-Widmann <devmotion@users.noreply.github.com> --------- Co-authored-by: Desef, Benjamin <benjamin.desef@dlr.de> Co-authored-by: David Müller-Widmann <devmotion@users.noreply.github.com>
1 parent 821c38f commit d04d20c

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

src/matrix/matrixnormal.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,13 @@ end
122122
# https://en.wikipedia.org/wiki/Matrix_normal_distribution#Drawing_values_from_the_distribution
123123
124124
function _rand!(rng::AbstractRNG, d::MatrixNormal, Y::AbstractMatrix)
125-
n, p = size(d)
126-
X = randn(rng, n, p)
125+
randn!(rng, Y)
127126
A = cholesky(d.U).L
128127
B = cholesky(d.V).U
129-
Y .= d.M .+ A * X * B
128+
lmul!(A, Y)
129+
rmul!(Y, B)
130+
Y .+= d.M
131+
return Y
130132
end
131133
132134
# -----------------------------------------------------------------------------

test/matrixvariates.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,19 @@ function test_special(dist::Type{MatrixNormal})
322322
end
323323
end
324324
end
325+
@testset "Non-allocating sampling" begin
326+
# #2012: we can sample without allocations
327+
M, U, V = _rand_params(MatrixNormal, Float64, 5, 5)
328+
noallocD = MatrixNormal(M, cholesky!(Symmetric(U, :L)), cholesky!(Symmetric(V, :U)))
329+
output = Matrix{Float64}(undef, size(noallocD))
330+
allocs = ((d, out) -> @allocated(rand!(d, out)))(noallocD, output)
331+
# See https://github.com/JuliaStats/Distributions.jl/pull/2012#issuecomment-3566807876
332+
if VERSION < v"1.10.5"
333+
@test allocs <= 32
334+
else
335+
@test iszero(allocs)
336+
end
337+
end
325338
nothing
326339
end
327340

0 commit comments

Comments
 (0)