Skip to content

Commit 1b37be2

Browse files
Merge pull request #1423 from SciML/ChrisRackauckas-patch-4-1
If sparse, make the mass matrix sparse
2 parents 274c80c + fd75481 commit 1b37be2

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ function calculate_massmatrix(sys::AbstractODESystem; simplify=false)
282282
end
283283
M = simplify ? ModelingToolkit.simplify.(M) : M
284284
# M should only contain concrete numbers
285-
M == I ? I : M
285+
M === I ? I : M
286286
end
287287

288288
jacobian_sparsity(sys::AbstractODESystem) =
@@ -355,8 +355,14 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
355355
end
356356

357357
M = calculate_massmatrix(sys)
358-
359-
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
358+
359+
_M = if sparse && !(u0 === nothing || M === I)
360+
SparseArrays.sparse(M)
361+
elseif u0 === nothing || M === I
362+
M
363+
else
364+
ArrayInterface.restructure(u0 .* u0',M)
365+
end
360366

361367
obs = observed(sys)
362368
observedfun = if steady_state
@@ -509,7 +515,13 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
509515

510516
M = calculate_massmatrix(sys)
511517

512-
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0',M)
518+
_M = if sparse && !(u0 === nothing || M === I)
519+
SparseArrays.sparse(M)
520+
elseif u0 === nothing || M === I
521+
M
522+
else
523+
ArrayInterface.restructure(u0 .* u0',M)
524+
end
513525

514526
jp_expr = sparse ? :(similar($(get_jac(sys)[]),Float64)) : :nothing
515527
ex = quote

test/odesystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ D = Differential(t)
356356
eqs = [D(x1) ~ -x1]
357357
@named sys = ODESystem(eqs,t,[x1,x2],[])
358358
@test_throws ArgumentError ODEProblem(sys, [1.0,1.0], (0.0,1.0))
359-
prob = ODEProblem(sys, [1.0,1.0], (0.0,1.0), check_length=false)
359+
@test_throws DimensionMismatch ODEProblem(sys, [1.0,1.0], (0.0,1.0), check_length=false)
360360

361361
# check inputs
362362
let

0 commit comments

Comments
 (0)