Skip to content

Commit bd8913b

Browse files
committed
feat: add W sparsity/generation functions
1 parent a71ba77 commit bd8913b

File tree

2 files changed

+49
-10
lines changed

2 files changed

+49
-10
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,33 @@ function generate_jacobian(sys::AbstractODESystem, dvs = unknowns(sys),
126126
dvs,
127127
p...,
128128
get_iv(sys);
129+
wrap_code = sparse ? assert_jac_length_header(sys) : (identity, identity),
130+
kwargs...)
131+
end
132+
133+
function assert_jac_length_header(sys)
134+
W = W_sparsity(sys)
135+
136+
identity, expr -> Func([expr.args...], [], LiteralExpr(quote
137+
@assert nnz($(expr.args[1])) == nnz(W)
138+
expr.body
139+
end))
140+
end
141+
142+
function generate_W(sys::AbstractODESystem, γ = 1., dvs = unknowns(sys),
143+
ps = parameters(sys; initial_parameters = true);
144+
simplify = false, sparse = false, kwargs...)
145+
@variables ˍ₋gamma
146+
M = calculate_massmatrix(sys; simplify)
147+
J = calculate_jacobian(sys; simplify, sparse, dvs)
148+
W = ˍ₋gamma*M + J
149+
150+
p = reorder_parameters(sys, ps)
151+
return build_function_wrapper(sys, W,
152+
dvs,
153+
p...,
154+
ˍ₋gamma,
155+
get_iv(sys);
129156
kwargs...)
130157
end
131158

@@ -264,6 +291,12 @@ function jacobian_dae_sparsity(sys::AbstractODESystem)
264291
J1 + J2
265292
end
266293

294+
function W_sparsity(sys::AbstractODESystem)
295+
jac_sparsity = jacobian_sparsity(sys)
296+
M_sparsity = sparse(iszero.(calculate_massmatrix(sys)))
297+
jac_sparsity .|| M_sparsity
298+
end
299+
267300
function isautonomous(sys::AbstractODESystem)
268301
tgrad = calculate_tgrad(sys; simplify = true)
269302
all(iszero, tgrad)
@@ -368,15 +401,17 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
368401
observedfun = ObservedFunctionCache(
369402
sys; steady_state, eval_expression, eval_module, checkbounds, cse)
370403

371-
jac_prototype = if sparse
404+
if sparse
372405
uElType = u0 === nothing ? Float64 : eltype(u0)
373406
if jac
374-
similar(calculate_jacobian(sys, sparse = sparse), uElType)
407+
jac_prototype = similar(calculate_jacobian(sys; sparse), uElType)
375408
else
376-
similar(jacobian_sparsity(sys), uElType)
409+
jac_prototype = similar(jacobian_sparsity(sys), uElType)
377410
end
411+
W_prototype = similar(W_sparsity(sys), uElType)
378412
else
379-
nothing
413+
jac_prototype = nothing
414+
W_prototype = nothing
380415
end
381416

382417
@set! sys.split_idxs = split_idxs
@@ -386,7 +421,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
386421
jac = _jac === nothing ? nothing : _jac,
387422
tgrad = _tgrad === nothing ? nothing : _tgrad,
388423
mass_matrix = _M,
389-
jac_prototype = jac_prototype,
424+
jac_prototype = W_prototype,
425+
W_prototype = W_prototype,
390426
observed = observedfun,
391427
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
392428
analytic = analytic,

src/systems/diffeqs/sdesystem.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -642,15 +642,17 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
642642
_Wfact, _Wfact_t = nothing, nothing
643643
end
644644

645-
jac_prototype = if sparse
645+
if sparse
646646
uElType = u0 === nothing ? Float64 : eltype(u0)
647647
if jac
648-
similar(calculate_jacobian(sys, sparse = sparse), uElType)
648+
jac_prototype = similar(calculate_jacobian(sys; sparse), uElType)
649649
else
650-
similar(jacobian_sparsity(sys), uElType)
650+
jac_prototype = similar(jacobian_sparsity(sys), uElType)
651651
end
652+
W_prototype = similar(W_sparsity(sys), uElType)
652653
else
653-
nothing
654+
jac_prototype = nothing
655+
W_prototype = nothing
654656
end
655657

656658
M = calculate_massmatrix(sys)
@@ -664,7 +666,8 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
664666
jac = _jac === nothing ? nothing : _jac,
665667
tgrad = _tgrad === nothing ? nothing : _tgrad,
666668
mass_matrix = _M,
667-
jac_prototype = jac_prototype,
669+
jac_prototype = W_prototype,
670+
W_prototype = W_prototype,
668671
observed = observedfun,
669672
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
670673
analytic = analytic,

0 commit comments

Comments
 (0)