Skip to content

Commit 832d692

Browse files
working ODE jacobians
1 parent c28bb73 commit 832d692

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

src/systems/diffeqs/diffeqsystem.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,11 @@ function build_equals_expr(eq)
5555
end
5656
end
5757

58-
function generate_ode_jacobian(sys::DiffEqSystem,simplify=true)
59-
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
60-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
58+
function calculate_jacobian(sys::DiffEqSystem,simplify=true)
6159
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
6260
diff_exprs = sys.eqs[diff_idxs]
6361
rhs = [eq.args[2] for eq in diff_exprs]
62+
# Handle intermediate calculations by substitution
6463
calcs = sys.eqs[.!(diff_idxs)]
6564
for i in 1:length(calcs)
6665
find_replace!.(rhs,calcs[i].args[1],calcs[i].args[2])
@@ -73,6 +72,18 @@ function generate_ode_jacobian(sys::DiffEqSystem,simplify=true)
7372
sys_exprs
7473
end
7574

75+
function generate_ode_jacobian(sys::DiffEqSystem,simplify=true)
76+
var_exprs = [:($(sys.dvs[i].name) = u[$i]) for i in 1:length(sys.dvs)]
77+
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
78+
diff_idxs = map(eq->eq.args[1].diff !=nothing,sys.eqs)
79+
diff_exprs = sys.eqs[diff_idxs]
80+
jac = calculate_jacobian(sys,simplify)
81+
jac_exprs = [:(J[$i,$j] = $(Expr(jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
82+
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
83+
block = expr_arr_to_block(exprs)
84+
:((J,u,p,t)->$(block))
85+
end
86+
7687
function DiffEqBase.DiffEqFunction(sys::DiffEqSystem)
7788
expr = generate_ode_function(sys)
7889
DiffEqFunction{true}(eval(expr))

test/system_construction.jl

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,15 @@ eqs = [D*x ~ σ*(y-x),
1515
D*z ~ x*y - β*z]
1616
de = DiffEqSystem(eqs,[t],[x,y,z],Variable[],[σ,ρ,β])
1717
SciCompDSL.generate_ode_function(de)
18-
jac = SciCompDSL.generate_ode_jacobian(de,false)
19-
jac = SciCompDSL.generate_ode_jacobian(de)
18+
jac_expr = SciCompDSL.generate_ode_jacobian(de)
19+
jac = SciCompDSL.calculate_jacobian(de)
2020
f = DiffEqFunction(de)
2121
I - jac
22-
@test_broken inv(jac)
22+
@test_broken inv(I - jac)
2323

2424
# Differential equation with automatic extraction of variables on rhs
2525
de2 = DiffEqSystem(eqs, [t])
2626

27-
2827
function test_vars_extraction(de, de2)
2928
for el in (:ivs, :dvs, :vs, :ps)
3029
names2 = sort(collect(var.name for var in getfield(de2,el)))
@@ -67,8 +66,7 @@ eqs = [a ~ y-x,
6766
D*z ~ x*y - β*z]
6867
de = DiffEqSystem(eqs,[t],[x,y,z],[a],[σ,ρ,β])
6968
SciCompDSL.generate_ode_function(de)
70-
jac = SciCompDSL.generate_ode_jacobian(de,false)
71-
jac = SciCompDSL.generate_ode_jacobian(de)
69+
jac = SciCompDSL.calculate_jacobian(de)
7270
f = DiffEqFunction(de)
7371

7472
# Define a nonlinear system
@@ -95,7 +93,6 @@ eqs = [0 ~ σ*(y-x),
9593
0 ~ x*y - β*z]
9694
ns = NonlinearSystem(eqs)
9795
nlsys_func = SciCompDSL.generate_nlsys_function(ns)
98-
jac = SciCompDSL.generate_nlsys_jacobian(ns,false)
9996
jac = SciCompDSL.generate_nlsys_jacobian(ns)
10097
f = @eval eval(nlsys_func)
10198

@@ -107,5 +104,4 @@ eqs = [a ~ y-x,
107104
0 ~ x*y - β*z]
108105
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
109106
nlsys_func = SciCompDSL.generate_nlsys_function(ns)
110-
jac = SciCompDSL.generate_nlsys_jacobian(ns,false)
111107
jac = SciCompDSL.generate_nlsys_jacobian(ns)

0 commit comments

Comments
 (0)