diff --git a/src/helpers/algebra/common.jl b/src/helpers/algebra/common.jl index 9d6f54874..1a09a953e 100644 --- a/src/helpers/algebra/common.jl +++ b/src/helpers/algebra/common.jl @@ -99,7 +99,7 @@ function mul_trace(A::AbstractMatrix, B::AbstractMatrix) n = first(sA) for i in 1:n for j in 1:n - result += A[i, j] * B[j, i] + @inbounds result += A[i, j] * B[j, i] end end return result diff --git a/src/rules/continuous_transition/W.jl b/src/rules/continuous_transition/W.jl index c6afa6d1d..5159a6a13 100644 --- a/src/rules/continuous_transition/W.jl +++ b/src/rules/continuous_transition/W.jl @@ -8,10 +8,20 @@ function compute_delta(my, Vy, mx, Vx, Vyx, mA, Va, ma, Fs) G₅ = zeros(eltype(ma), dy, dy) G₆ = zeros(eltype(ma), dy, dy) mamat = ma * ma' - for (i, j) in Iterators.product(1:dy, 1:dy) - tmp = Fs[i]' * Ex_xx * Fs[j] - G₅[i, j] = mul_trace(tmp, mamat) - G₆[i, j] = mul_trace(tmp, Va) + + Y = similar(Ex_xx) + Z = similar(Ex_xx) + + @inbounds for (i, j) in Iterators.product(1:dy, 1:dy) + mul!(Y, Ex_xx, Fs[j]) + mul!(Z, Fs[i]', Y) + + G₅[i, j] = mul_trace(Z, mamat) + G₆[i, j] = mul_trace(Z, Va) + + # tmp = Fs[i]' * Ex_xx * Fs[j] + # G₅[i, j] = mul_trace(tmp, mamat) + # G₆[i, j] = mul_trace(tmp, Va) end G = G₁ - (G₂ + G₃) .+ Symmetric(G₅ + G₆) diff --git a/src/rules/continuous_transition/a.jl b/src/rules/continuous_transition/a.jl index d3dd2f87b..1ce867a61 100644 --- a/src/rules/continuous_transition/a.jl +++ b/src/rules/continuous_transition/a.jl @@ -17,10 +17,16 @@ Vxymxy = rank1update(Vyx', mx, my) Vxmx = rank1update(Vx, mx) + + Y = similar(Vxmx) + for i in 1:dy xi += Fs[i]' * Vxymxy * mW[:, i] for j in 1:dy - W += mW[j, i] * Fs[i]' * Vxmx * Fs[j] + mul!(Y, Vxmx, Fs[j]) + mul!(W, Fs[i]', Y, mW[j, i], 1) + + # W += mW[j, i] * Fs[i]' * Vxmx * Fs[j] end end diff --git a/src/rules/continuous_transition/marginals.jl b/src/rules/continuous_transition/marginals.jl index abd21539b..3d1b9ac9b 100644 --- a/src/rules/continuous_transition/marginals.jl +++ b/src/rules/continuous_transition/marginals.jl @@ -25,8 +25,13 @@ function continuous_tranition_marginal(m_y::MultivariateNormalDistributionsFamil W_21 = negate_inplace!(mA' * mW) Ξ = Wx + + Y = similar(Va) for (i, j) in Iterators.product(1:dy, 1:dy) - Ξ += mW[j, i] * Fs[j] * Va * Fs[i]' + mul!(Y, Va, Fs[i]') + mul!(Ξ, Fs[j], Y, mW[j, i], 1) + + # Ξ += mW[j, i] * Fs[j] * Va * Fs[i]' end W_22 = Ξ + mA' * mW * mA diff --git a/src/rules/continuous_transition/x.jl b/src/rules/continuous_transition/x.jl index 20a917ebe..79a32a6cc 100644 --- a/src/rules/continuous_transition/x.jl +++ b/src/rules/continuous_transition/x.jl @@ -15,8 +15,13 @@ WymW = Wy - Wy * cholinv(Wy + mW) * Wy Ξ = mA' * WymW * mA + Y = similar(Va) + for (i, j) in Iterators.product(1:dy, 1:dy) - Ξ += mW[j, i] * Fs[j] * Va * Fs[i]' + mul!(Y, Va, Fs[i]') + mul!(Ξ, Fs[j], Y, mW[j, i], 1) + + # Ξ += mW[j, i] * Fs[j] * Va * Fs[i]' end z = mA' * WymW * my