Skip to content

Commit df73130

Browse files
authored
Merge pull request #1036 from SciML/myb/jac
Compute Jacobian sparsity when u0 is not provided
2 parents 9ae858e + 02d5949 commit df73130

File tree

4 files changed

+37
-17
lines changed

4 files changed

+37
-17
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ end
1919

2020
function calculate_jacobian(sys::AbstractODESystem;
2121
sparse=false, simplify=false)
22-
isempty(get_jac(sys)[]) || return get_jac(sys)[] # use cached Jacobian, if possible
22+
cache = get_jac(sys)[]
23+
if cache isa Tuple && cache[2] == (sparse, simplify)
24+
return cache[1]
25+
end
2326
rhs = [eq.rhs for eq equations(sys)]
2427

2528
iv = get_iv(sys)
@@ -31,7 +34,7 @@ function calculate_jacobian(sys::AbstractODESystem;
3134
jac = jacobian(rhs, dvs, simplify=simplify)
3235
end
3336

34-
get_jac(sys)[] = jac # cache Jacobian
37+
get_jac(sys)[] = jac, (sparse, simplify) # cache Jacobian
3538
return jac
3639
end
3740

@@ -217,13 +220,22 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
217220
end
218221
end
219222

220-
uElType = eltype(u0)
223+
jac_prototype = if sparse
224+
uElType = u0 === nothing ? Float64 : eltype(u0)
225+
if jac
226+
similar(calculate_jacobian(sys, sparse=sparse), uElType)
227+
else
228+
similar(jacobian_sparsity(sys), uElType)
229+
end
230+
else
231+
nothing
232+
end
221233
ODEFunction{iip}(
222234
f,
223235
jac = _jac === nothing ? nothing : _jac,
224236
tgrad = _tgrad === nothing ? nothing : _tgrad,
225237
mass_matrix = _M,
226-
jac_prototype = (!isnothing(u0) && sparse) ? (!jac ? similar(jacobian_sparsity(sys),uElType) : SparseArrays.sparse(similar(get_jac(sys)[],uElType))) : nothing,
238+
jac_prototype = jac_prototype,
227239
syms = Symbol.(states(sys)),
228240
indepsym = Symbol(independent_variable(sys)),
229241
observed = observedfun,

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,20 @@ function NonlinearSystem(eqs, states, ps;
7676
NonlinearSystem(eqs, value.(states), value.(ps), observed, jac, name, systems, defaults, nothing, connection_type)
7777
end
7878

79-
function calculate_jacobian(sys::NonlinearSystem;sparse=false,simplify=false)
79+
function calculate_jacobian(sys::NonlinearSystem; sparse=false, simplify=false)
80+
cache = get_jac(sys)[]
81+
if cache isa Tuple && cache[2] == (sparse, simplify)
82+
return cache[1]
83+
end
84+
8085
rhs = [eq.rhs for eq equations(sys)]
8186
vals = [dv for dv in states(sys)]
8287
if sparse
8388
jac = sparsejacobian(rhs, vals, simplify=simplify)
8489
else
8590
jac = jacobian(rhs, vals, simplify=simplify)
8691
end
87-
get_jac(sys)[] = jac
92+
get_jac(sys)[] = jac, (sparse, simplify)
8893
return jac
8994
end
9095

@@ -170,7 +175,7 @@ function DiffEqBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sy
170175

171176
NonlinearFunction{iip}(f,
172177
jac = _jac === nothing ? nothing : _jac,
173-
jac_prototype = sparse ? similar(sys.jac[],Float64) : nothing,
178+
jac_prototype = sparse ? similar(calculate_jacobian(sys, sparse=sparse),Float64) : nothing,
174179
syms = Symbol.(states(sys)), observed = observedfun)
175180
end
176181

test/jacobiansparsity.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using OrdinaryDiffEq, ModelingToolkit, Test, SparseArrays
22

3-
N = 32
4-
const xyd_brusselator = range(0,stop=1,length=N)
3+
N = 3
4+
xyd_brusselator = range(0,stop=1,length=N)
55
brusselator_f(x, y, t) = (((x-0.3)^2 + (y-0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.
66
limit(a, N) = ModelingToolkit.ifelse(a == N+1, 1, ModelingToolkit.ifelse(a == 0, N, a))
77
function brusselator_2d_loop(du, u, p, t)
@@ -40,11 +40,12 @@ sys = modelingtoolkitize(prob_ode_brusselator_2d)
4040

4141
# test sparse jacobian pattern only.
4242
prob = ODEProblem(sys, u0, (0, 11.5), sparse=true, jac=false)
43-
@test findnz(Symbolics.jacobian_sparsity(map(x->x.rhs, equations(sys)), states(sys)))[1:2] == findnz(prob.f.jac_prototype)[1:2]
43+
JP = prob.f.jac_prototype
44+
@test findnz(Symbolics.jacobian_sparsity(map(x->x.rhs, equations(sys)), states(sys)))[1:2] == findnz(JP)[1:2]
4445

4546
# test sparse jacobian
4647
prob = ODEProblem(sys, u0, (0, 11.5), sparse=true, jac=true)
47-
@test findnz(calculate_jacobian(sys))[1:2] == findnz(prob.f.jac_prototype)[1:2]
48+
@test findnz(calculate_jacobian(sys, sparse=true))[1:2] == findnz(prob.f.jac_prototype)[1:2]
4849

4950
# test when not sparse
5051
prob = ODEProblem(sys, u0, (0, 11.5), sparse=false, jac=true)
@@ -55,10 +56,12 @@ prob = ODEProblem(sys, u0, (0, 11.5), sparse=false, jac=false)
5556

5657
# test when u0 is nothing
5758
f = DiffEqBase.ODEFunction(sys, u0=nothing, sparse=true, jac=true)
58-
@test f.jac_prototype == nothing
59+
@test findnz(f.jac_prototype)[1:2] == findnz(JP)[1:2]
60+
@test eltype(f.jac_prototype) == Float64
5961

6062
f = DiffEqBase.ODEFunction(sys, u0=nothing, sparse=true, jac=false)
61-
@test f.jac_prototype == nothing
63+
@test findnz(f.jac_prototype)[1:2] == findnz(JP)[1:2]
64+
@test eltype(f.jac_prototype) == Float64
6265

6366
# test when u0 is not Float64
6467
u0 = similar(init_brusselator_2d(xyd_brusselator), Float32)

test/nonlinearsystem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,21 +117,21 @@ eqs = [0 ~ σ*(y-x),
117117
0 ~ x*y - β*z]
118118
ns = NonlinearSystem(eqs, [x,y,z], [σ,ρ,β])
119119
np = NonlinearProblem(ns, [0,0,0], [1,2,3], jac=true, sparse=true)
120-
@test ModelingToolkit.get_jac(ns)[] isa SparseMatrixCSC
120+
@test calculate_jacobian(ns, sparse=true) isa SparseMatrixCSC
121121

122122
# issue #819
123123
@testset "Combined system name collisions" begin
124124
function makesys(name)
125125
@parameters a
126126
@variables x f
127-
127+
128128
NonlinearSystem([0 ~ -a*x + f],[x,f],[a], name=name)
129129
end
130-
130+
131131
function issue819()
132132
sys1 = makesys(:sys1)
133133
sys2 = makesys(:sys1)
134134
@test_throws ArgumentError NonlinearSystem([sys2.f ~ sys1.x, sys1.f ~ 0],[],[], systems=[sys1, sys2])
135135
end
136136
issue819()
137-
end
137+
end

0 commit comments

Comments
 (0)