Skip to content

Commit 753cae3

Browse files
refactor: centralize mass matrix and W sparsity handling
1 parent 4ee32df commit 753cae3

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

src/problems/odeproblem.jl

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,13 @@
3535
end
3636

3737
M = calculate_massmatrix(sys)
38-
39-
_M = if sparse && !(u0 === nothing || M === I)
40-
SparseArrays.sparse(M)
41-
elseif u0 === nothing || M === I
42-
M
43-
else
44-
ArrayInterface.restructure(u0 .* u0', M)
45-
end
38+
_M = concrete_massmatrix(M; sparse, u0)
4639

4740
observedfun = ObservedFunctionCache(
4841
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
4942

50-
if sparse
51-
uElType = u0 === nothing ? Float64 : eltype(u0)
52-
W_prototype = similar(W_sparsity(sys), uElType)
53-
else
54-
W_prototype = nothing
55-
end
43+
_W_sparsity = W_sparsity(sys)
44+
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
5645

5746
ODEFunction{iip, spec}(f;
5847
sys = sys,
@@ -61,7 +50,7 @@
6150
mass_matrix = _M,
6251
jac_prototype = W_prototype,
6352
observed = observedfun,
64-
sparsity = sparsity ? W_sparsity(sys) : nothing,
53+
sparsity = sparsity ? _W_sparsity : nothing,
6554
analytic = analytic,
6655
initialization_data)
6756
end

src/systems/codegen.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,16 @@ function calculate_massmatrix(sys::System; simplify = false)
240240
M == I ? I : M
241241
end
242242

243+
function concrete_massmatrix(M; sparse = false, u0 = nothing)
244+
if sparse && !(u0 === nothing || M === I)
245+
SparseArrays.sparse(M)
246+
elseif u0 === nothing || M === I
247+
M
248+
else
249+
ArrayInterface.restructure(u0 .* u0', M)
250+
end
251+
end
252+
243253
function jacobian_sparsity(sys::System)
244254
sparsity = torn_system_jacobian_sparsity(sys)
245255
sparsity === nothing || return sparsity
@@ -266,6 +276,12 @@ function W_sparsity(sys::System)
266276
jac_sparsity .| M_sparsity
267277
end
268278

279+
function calculate_W_prototype(W_sparsity; u0 = nothing, sparse = false)
280+
sparse || return nothing
281+
uElType = u0 === nothing ? Float64 : eltype(u0)
282+
return similar(W_sparsity, uElType)
283+
end
284+
269285
function isautonomous(sys::System)
270286
tgrad = calculate_tgrad(sys; simplify = true)
271287
all(iszero, tgrad)

0 commit comments

Comments
 (0)