Skip to content

Commit 8b8d549

Browse files
akashkgargAkash Garg
andauthored
Return jacobian_sparsity() when sparse && !jac (#968)
* Return jacobian_sparsity() when sparse && !jac Issue #871 * Convert jac_prototype to be eltype of u0 * Checking if u0 is nothing. Adding tests. * updating test to test sparsity pattern * adding missing import Co-authored-by: Akash Garg <[email protected]>
1 parent f12f472 commit 8b8d549

File tree

3 files changed

+76
-1
lines changed

3 files changed

+76
-1
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,12 +215,13 @@ function DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),
215215
end
216216
end
217217

218+
uElType = eltype(u0)
218219
ODEFunction{iip}(
219220
f,
220221
jac = _jac === nothing ? nothing : _jac,
221222
tgrad = _tgrad === nothing ? nothing : _tgrad,
222223
mass_matrix = _M,
223-
jac_prototype = sparse ? similar(get_jac(sys)[],Float64) : nothing,
224+
jac_prototype = (!isnothing(u0) && sparse) ? (!jac ? similar(jacobian_sparsity(sys),uElType) : similar(get_jac(sys)[],uElType)) : nothing,
224225
syms = Symbol.(states(sys)),
225226
indepsym = Symbol(independent_variable(sys)),
226227
observed = observedfun,

test/jacobiansparsity.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using OrdinaryDiffEq, ModelingToolkit, Test, SparseArrays
2+
3+
N = 32
4+
const xyd_brusselator = range(0,stop=1,length=N)
5+
brusselator_f(x, y, t) = (((x-0.3)^2 + (y-0.6)^2) <= 0.1^2) * (t >= 1.1) * 5.
6+
limit(a, N) = ModelingToolkit.ifelse(a == N+1, 1, ModelingToolkit.ifelse(a == 0, N, a))
7+
function brusselator_2d_loop(du, u, p, t)
8+
A, B, alpha, dx = p
9+
alpha = alpha/dx^2
10+
@inbounds for I in CartesianIndices((N, N))
11+
i, j = Tuple(I)
12+
x, y = xyd_brusselator[I[1]], xyd_brusselator[I[2]]
13+
ip1, im1, jp1, jm1 = limit(i+1, N), limit(i-1, N), limit(j+1, N), limit(j-1, N)
14+
du[i,j,1] = alpha*(u[im1,j,1] + u[ip1,j,1] + u[i,jp1,1] + u[i,jm1,1] - 4u[i,j,1]) +
15+
B + u[i,j,1]^2*u[i,j,2] - (A + 1)*u[i,j,1] + brusselator_f(x, y, t)
16+
du[i,j,2] = alpha*(u[im1,j,2] + u[ip1,j,2] + u[i,jp1,2] + u[i,jm1,2] - 4u[i,j,2]) +
17+
A*u[i,j,1] - u[i,j,1]^2*u[i,j,2]
18+
end
19+
end
20+
21+
# Test with tuple parameters
22+
p = (3.4, 1., 10., step(xyd_brusselator))
23+
24+
function init_brusselator_2d(xyd)
25+
N = length(xyd)
26+
u = zeros(N, N, 2)
27+
for I in CartesianIndices((N, N))
28+
x = xyd[I[1]]
29+
y = xyd[I[2]]
30+
u[I,1] = 22*(y*(1-y))^(3/2)
31+
u[I,2] = 27*(x*(1-x))^(3/2)
32+
end
33+
u
34+
end
35+
36+
u0 = init_brusselator_2d(xyd_brusselator)
37+
prob_ode_brusselator_2d = ODEProblem(brusselator_2d_loop,
38+
u0,(0.,11.5),p)
39+
sys = modelingtoolkitize(prob_ode_brusselator_2d)
40+
41+
# test sparse jacobian pattern only.
42+
prob = ODEProblem(sys, u0, (0, 11.5), sparse=true, jac=false)
43+
@test findnz(Symbolics.jacobian_sparsity(map(x->x.rhs, equations(sys)), states(sys)))[1:2] == findnz(prob.f.jac_prototype)[1:2]
44+
45+
# test sparse jacobian
46+
prob = ODEProblem(sys, u0, (0, 11.5), sparse=true, jac=true)
47+
@test findnz(calculate_jacobian(sys))[1:2] == findnz(prob.f.jac_prototype)[1:2]
48+
49+
# test when not sparse
50+
prob = ODEProblem(sys, u0, (0, 11.5), sparse=false, jac=true)
51+
@test prob.f.jac_prototype == nothing
52+
53+
prob = ODEProblem(sys, u0, (0, 11.5), sparse=false, jac=false)
54+
@test prob.f.jac_prototype == nothing
55+
56+
# test when u0 is nothing
57+
f = DiffEqBase.ODEFunction(sys, u0=nothing, sparse=true, jac=true)
58+
@test f.jac_prototype == nothing
59+
60+
f = DiffEqBase.ODEFunction(sys, u0=nothing, sparse=true, jac=false)
61+
@test f.jac_prototype == nothing
62+
63+
# test when u0 is not Float64
64+
u0 = similar(init_brusselator_2d(xyd_brusselator), Float32)
65+
prob_ode_brusselator_2d = ODEProblem(brusselator_2d_loop,
66+
u0,(0.,11.5),p)
67+
sys = modelingtoolkitize(prob_ode_brusselator_2d)
68+
69+
prob = ODEProblem(sys, u0, (0, 11.5), sparse=true, jac=false)
70+
@test eltype(prob.f.jac_prototype) == Float32
71+
72+
prob = ODEProblem(sys, u0, (0, 11.5), sparse=true, jac=true)
73+
@test eltype(prob.f.jac_prototype) == Float32

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ using SafeTestsets, Test
3030
@safetestset "Precompiled Modules Test" begin include("precompile_test.jl") end
3131
@testset "Distributed Test" begin include("distributed.jl") end
3232
@safetestset "Variable Utils Test" begin include("variable_utils.jl") end
33+
@safetestset "Jacobian Sparsity" begin include("jacobiansparsity.jl") end
3334
println("Last test requires gcc available in the path!")
3435
@safetestset "C Compilation Test" begin include("ccompile.jl") end
3536
@safetestset "Latexify recipes Test" begin include("latexify.jl") end

0 commit comments

Comments
 (0)