From 32a7781951e0f0e036a3d6a43783635caff120f2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 9 May 2025 13:38:14 +0530 Subject: [PATCH] fix: use `Diagonal` for diagonal mass matrices --- src/systems/diffeqs/abstractodesystem.jl | 5 +++++ src/systems/diffeqs/sdesystem.jl | 10 +++++++++- test/mass_matrix.jl | 21 ++++++++++++++++++++- 3 files changed, 34 insertions(+), 2 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index e66f03a85e..942e508644 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -284,6 +284,9 @@ function calculate_massmatrix(sys::AbstractODESystem; simplify = false) end M = simplify ? ModelingToolkit.simplify.(M) : M # M should only contain concrete numbers + if isdiag(M) + M = Diagonal(M) + end M == I ? I : M end @@ -410,6 +413,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, SparseArrays.sparse(M) elseif u0 === nothing || M === I M + elseif M isa Diagonal + Diagonal(ArrayInterface.restructure(u0, diag(M))) else ArrayInterface.restructure(u0 .* u0', M) end diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index c5299c28be..d743143e46 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -652,7 +652,15 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns( W_prototype = nothing end - _M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M) + _M = if sparse && !(u0 === nothing || M === I) + SparseArrays.sparse(M) + elseif u0 === nothing || M === I + M + elseif M isa Diagonal + Diagonal(ArrayInterface.restructure(u0, diag(M))) + else + ArrayInterface.restructure(u0 .* u0', M) + end observedfun = ObservedFunctionCache( sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false), cse) diff --git a/test/mass_matrix.jl b/test/mass_matrix.jl index 5183b4ab3f..8b31123834 100644 --- a/test/mass_matrix.jl +++ b/test/mass_matrix.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, ModelingToolkit, Test, LinearAlgebra +using OrdinaryDiffEq, ModelingToolkit, Test, LinearAlgebra, StaticArrays using ModelingToolkit: t_nounits as t, D_nounits as D, MTKParameters @variables y(t)[1:3] @@ -12,13 +12,18 @@ eqs = [D(y[1]) ~ -k[1] * y[1] + k[3] * y[2] * y[3], sys = complete(sys) @test_throws ArgumentError ODESystem(eqs, y[1]) M = calculate_massmatrix(sys) +@test M isa Diagonal @test M == [1 0 0 0 1 0 0 0 0] prob_mm = ODEProblem(sys, [y => [1.0, 0.0, 0.0]], (0.0, 1e5), [k => [0.04, 3e7, 1e4]]) +@test prob_mm.f.mass_matrix isa Diagonal{Float64, Vector{Float64}} sol = solve(prob_mm, Rodas5(), reltol = 1e-8, abstol = 1e-8) +prob_mm = ODEProblem(sys, SA[y => [1.0, 0.0, 0.0]], (0.0, 1e5), + [k => [0.04, 3e7, 1e4]]) +@test prob_mm.f.mass_matrix isa Diagonal{Float64, SVector{3, Float64}} function rober(du, u, p, t) y₁, y₂, y₃ = u @@ -43,3 +48,17 @@ eqs = [D(y[1]) ~ y[1], D(y[2]) ~ y[2], D(y[3]) ~ y[3]] @named sys = ODESystem(eqs, t, collect(y), [k]) @test calculate_massmatrix(sys) === I + +@testset "Mass matrix `isa Diagonal` for `SDEProblem`" begin + eqs = [D(y[1]) ~ -k[1] * y[1] + k[3] * y[2] * y[3], + D(y[2]) ~ k[1] * y[1] - k[3] * y[2] * y[3] - k[2] * y[2]^2, + 0 ~ y[1] + y[2] + y[3] - 1] + + @named sys = ODESystem(eqs, t, collect(y), [k]) + @named sys = SDESystem(sys, [1, 1, 0]) + sys = complete(sys) + prob = SDEProblem(sys, [y => [1.0, 0.0, 0.0]], (0.0, 1e5), [k => [0.04, 3e7, 1e4]]) + @test prob.f.mass_matrix isa Diagonal{Float64, Vector{Float64}} + prob = SDEProblem(sys, SA[y => [1.0, 0.0, 0.0]], (0.0, 1e5), [k => [0.04, 3e7, 1e4]]) + @test prob.f.mass_matrix isa Diagonal{Float64, SVector{3, Float64}} +end