Skip to content

Commit 7cc075d

Browse files
committed
Jacobian sparsity for ODEProblem
1 parent e477c2f commit 7cc075d

File tree

4 files changed

+32
-5
lines changed

4 files changed

+32
-5
lines changed

src/structural_transformation/StructuralTransformations.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@ using ModelingToolkit: ODESystem, AbstractSystem,var_from_nested_derivative, Dif
1616
isdiffeq, isdifferential, isinput,
1717
empty_substitutions, get_substitutions,
1818
get_structure, get_iv, independent_variables,
19-
get_structure, defaults, InvalidSystemException,
19+
has_structure, defaults, InvalidSystemException,
2020
ExtraEquationsSystemException,
2121
ExtraVariablesSystemException,
2222
get_postprocess_fbody, vars!,
23-
IncrementalCycleTracker, add_edge_checked!, topological_sort
23+
IncrementalCycleTracker, add_edge_checked!, topological_sort,
24+
invalidate_cache!
2425

2526
using ModelingToolkit.BipartiteGraphs
2627
import .BipartiteGraphs: invview
@@ -41,6 +42,7 @@ export tearing, dae_index_lowering, check_consistency
4142
export tearing_assignments, tearing_substitution
4243
export build_torn_function, build_observed_function, ODAEProblem
4344
export sorted_incidence_matrix
45+
export torn_system_jacobian_sparsity
4446

4547
include("utils.jl")
4648
include("pantelides.jl")

src/structural_transformation/codegen.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using ModelingToolkit: isdifferenceeq, has_continuous_events, generate_rootfindi
44

55
const MAX_INLINE_NLSOLVE_SIZE = 8
66

7-
function torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
7+
function torn_system_with_nlsolve_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs)
88
s = structure(sys)
99
@unpack fullvars, graph = s
1010

@@ -329,7 +329,7 @@ function build_torn_function(
329329

330330
ODEFunction{true}(
331331
@RuntimeGeneratedFunction(expr),
332-
sparsity = jacobian_sparsity ? torn_system_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing,
332+
sparsity = jacobian_sparsity ? torn_system_with_nlsolve_jacobian_sparsity(sys, var_eq_matching, var_sccs, nlsolve_scc_idxs, eqs_idxs, states_idxs) : nothing,
333333
syms = syms,
334334
observed = observedfun,
335335
mass_matrix = mass_matrix,

src/structural_transformation/utils.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,25 @@ function uneven_invmap(n::Int, list)
234234
return rename
235235
end
236236

237+
function torn_system_jacobian_sparsity(sys)
238+
has_structure(sys) || return nothing
239+
s = structure(sys)
240+
@unpack fullvars, graph = s
241+
242+
states_idxs = findall(!isdifferential, fullvars)
243+
var2idx = Dict{Int,Int}(v => i for (i, v) in enumerate(states_idxs))
244+
I = Int[]; J = Int[]
245+
for ieq in 𝑠vertices(graph)
246+
for ivar in 𝑠neighbors(graph, ieq)
247+
nivar = get(var2idx, ivar, 0)
248+
nivar == 0 && continue
249+
push!(I, ieq)
250+
push!(J, nivar)
251+
end
252+
end
253+
return sparse(I, J, true)
254+
end
255+
237256
###
238257
### Nonlinear equation(s) solving
239258
###

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,9 +281,13 @@ function calculate_massmatrix(sys::AbstractODESystem; simplify=false)
281281
M == I ? I : M
282282
end
283283

284-
jacobian_sparsity(sys::AbstractODESystem) =
284+
function jacobian_sparsity(sys::AbstractODESystem)
285+
sparsity = torn_system_jacobian_sparsity(sys)
286+
sparsity === nothing || return sparsity
287+
285288
jacobian_sparsity([eq.rhs for eq equations(sys)],
286289
[dv for dv in states(sys)])
290+
end
287291

288292
function isautonomous(sys::AbstractODESystem)
289293
tgrad = calculate_tgrad(sys;simplify=true)
@@ -319,6 +323,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
319323
eval_module = @__MODULE__,
320324
steady_state = false,
321325
checkbounds=false,
326+
sparsity=false,
322327
kwargs...) where {iip}
323328

324329
f_gen = generate_function(sys, dvs, ps; expression=Val{eval_expression}, expression_module=eval_module, checkbounds=checkbounds, kwargs...)
@@ -394,6 +399,7 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
394399
syms = Symbol.(states(sys)),
395400
indepsym = Symbol(get_iv(sys)),
396401
observed = observedfun,
402+
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
397403
)
398404
end
399405

0 commit comments

Comments
 (0)