From 182b772ae3d907df4b51d491d238ce775ac0546d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 12 Apr 2025 01:50:04 +0200 Subject: [PATCH 1/5] Fix MixtureModel with ScalMat covariances --- src/mixtures/mixturemodel.jl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index f25fb82d9c..44353fc732 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -186,7 +186,7 @@ function mean(d::MultivariateMixture) pi = p[i] if pi > 0.0 c = component(d, i) - BLAS.axpy!(pi, mean(c), m) + m .= muladd.(pi, mean(c), m) end end return m @@ -236,17 +236,21 @@ function cov(d::MultivariateMixture) pi = p[i] if pi > 0.0 c = component(d, i) - BLAS.axpy!(pi, mean(c), m) - BLAS.axpy!(pi, cov(c), V) + m .= muladd.(pi, mean(c), m) + V .= muladd.(pi, cov(c), V) end end for i = 1:K pi = p[i] if pi > 0.0 c = component(d, i) - # todo: use more in-place operations - md = mean(c) - m - BLAS.axpy!(pi, md*md', V) + md .= mean(c) .- m + BLAS.syr!('U', pi, md, V) + end + end + for j in 1:length(d) + for i in (j+1):length(d) + V[i, j] = V[j, i] end end return V @@ -303,7 +307,7 @@ function _mixpdf!(r::AbstractArray, d::AbstractMixtureModel, x) else pdf!(t, component(d, i), x) end - BLAS.axpy!(pi, t, r) + r .= muladd.(pi, t, r) end end return r From 264c7f52da88f959fb38c70749d938ff60433a0e Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 12 Apr 2025 20:04:55 +0200 Subject: [PATCH 2/5] Use `mul!` --- src/mixtures/mixturemodel.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index 44353fc732..64a72460d3 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -186,7 +186,7 @@ function mean(d::MultivariateMixture) pi = p[i] if pi > 0.0 c = component(d, i) - m .= muladd.(pi, mean(c), m) + mul!(m, pi, mean(c), true, true) end end return m @@ -236,8 +236,8 @@ function cov(d::MultivariateMixture) pi = p[i] if pi > 0.0 c = component(d, i) - m .= muladd.(pi, mean(c), m) - V .= muladd.(pi, cov(c), V) + mul!(m, pi, mean(c), true, true) + mul!(V, pi, cov(c), true, true) end end for i = 1:K @@ -307,7 +307,7 @@ function _mixpdf!(r::AbstractArray, d::AbstractMixtureModel, x) else pdf!(t, component(d, i), x) end - r .= muladd.(pi, t, r) + mul!(r, pi, t, true, true) end end return r From 3581a1742f564531bcc5162d827fbb285c42f9cf Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 12 Apr 2025 22:03:27 +0200 Subject: [PATCH 3/5] Use `LinearAlgebra.copytri!` --- src/mixtures/mixturemodel.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index 64a72460d3..650941221e 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -248,11 +248,7 @@ function cov(d::MultivariateMixture) BLAS.syr!('U', pi, md, V) end end - for j in 1:length(d) - for i in (j+1):length(d) - V[i, j] = V[j, i] - end - end + LinearAlgebra.copytri!(V, 'U') return V end From 63bd8673ddaa18eef99cdb5e15f8246917f5c123 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 14 Apr 2025 12:24:13 +0200 Subject: [PATCH 4/5] Update src/mixtures/mixturemodel.jl --- src/mixtures/mixturemodel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mixtures/mixturemodel.jl b/src/mixtures/mixturemodel.jl index 6fe7d75762..7f36642900 100644 --- a/src/mixtures/mixturemodel.jl +++ b/src/mixtures/mixturemodel.jl @@ -245,7 +245,7 @@ function cov(d::MultivariateMixture) if pi > 0.0 c = component(d, i) md .= mean(c) .- m - BLAS.syr!('U', pi, md, V) + BLAS.syr!('U', Float64(pi), md, V) end end LinearAlgebra.copytri!(V, 'U') From a065a65a676c58afa1c734b710c49f30a2606ca3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 14 Apr 2025 19:32:57 +0200 Subject: [PATCH 5/5] Add more tests --- test/mixture.jl | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/test/mixture.jl b/test/mixture.jl index 0b25a2346a..d2b2742ebc 100644 --- a/test/mixture.jl +++ b/test/mixture.jl @@ -1,5 +1,6 @@ using Distributions, Random using Test +using LinearAlgebra using ForwardDiff: Dual @@ -252,17 +253,19 @@ end end @testset "Testing MultivariatevariateMixture" begin - g_m = MixtureModel( - IsoNormal[ MvNormal([0.0, 0.0], I), - MvNormal([0.2, 1.0], I), - MvNormal([-0.5, -3.0], 1.6 * I) ], - [0.2, 0.5, 0.3]) - @test isa(g_m, MixtureModel{Multivariate, Continuous, IsoNormal}) - @test length(components(g_m)) == 3 - @test length(g_m) == 2 - @test insupport(g_m, [0.0, 0.0]) == true - test_mixture(g_m, 1000, 10^6, rng) - test_params(g_m) + for T in (Float32, Float64) + g_m = MixtureModel( + IsoNormal[ MvNormal([0.0, 0.0], I), + MvNormal([0.2, 1.0], I), + MvNormal([-0.5, -3.0], 1.6 * I) ], + T[0.2, 0.5, 0.3]) + @test isa(g_m, MixtureModel{Multivariate, Continuous, IsoNormal}) + @test length(components(g_m)) == 3 + @test length(g_m) == 2 + @test insupport(g_m, [0.0, 0.0]) + test_mixture(g_m, 1000, 10^6, rng) + test_params(g_m) + end u1 = Uniform() u2 = Uniform(1.0, 2.0)