Skip to content

Commit 9665d46

Browse files
working nlsys jacobians
1 parent 832d692 commit 9665d46

File tree

4 files changed

+15
-6
lines changed

4 files changed

+15
-6
lines changed

src/SciCompDSL.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ Base.promote_rule(::Type{T},::Type{T2}) where {T<:Number,T2<:Expression} = Expre
1515
Base.one(::Type{T}) where T<:Expression = Constant(1)
1616
Base.zero(::Type{T}) where T<:Expression = Constant(0)
1717

18+
function caclulate_jacobian end
19+
1820
include("operations.jl")
1921
include("operators.jl")
2022
include("systems/diffeqs/diffeqsystem.jl")

src/systems/nonlinear/nonlinear_system.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,10 @@ function generate_nlsys_function(sys::NonlinearSystem)
3737
:((du,u,p)->$(block))
3838
end
3939

40-
function generate_nlsys_jacobian(sys::NonlinearSystem,simplify=true)
41-
var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)]
42-
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
43-
40+
function calculate_jacobian(sys::NonlinearSystem,simplify=true)
4441
sys_idxs = map(eq->isequal(eq.args[1],Constant(0)),sys.eqs)
4542
sys_eqs = sys.eqs[sys_idxs]
4643
calc_eqs = sys.eqs[.!(sys_idxs)]
47-
sys_exprs = [:($(Symbol("resid[$i]")) = $(sys_eqs[i].args[2])) for i in eachindex(sys_eqs)]
4844
rhs = [eq.args[2] for eq in sys_eqs]
4945

5046
for i in 1:length(calc_eqs)
@@ -59,5 +55,15 @@ function generate_nlsys_jacobian(sys::NonlinearSystem,simplify=true)
5955
sys_exprs
6056
end
6157

58+
function generate_nlsys_jacobian(sys::NonlinearSystem,simplify=true)
59+
var_exprs = [:($(sys.vs[i].name) = u[$i]) for i in 1:length(sys.vs)]
60+
param_exprs = [:($(sys.ps[i].name) = p[$i]) for i in 1:length(sys.ps)]
61+
jac = calculate_jacobian(sys,simplify)
62+
jac_exprs = [:(J[$i,$j] = $(Expr(jac[i,j]))) for i in 1:size(jac,1), j in 1:size(jac,2)]
63+
exprs = vcat(var_exprs,param_exprs,vec(jac_exprs))
64+
block = expr_arr_to_block(exprs)
65+
:((J,u,p,t)->$(block))
66+
end
67+
6268
export NonlinearSystem
6369
export generate_nlsys_function

test/derivatives.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ eqs = [0 ~ σ*(y-x),
3030
0 ~ x*-z)-y,
3131
0 ~ x*y - β*z]
3232
sys = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
33-
jac = SciCompDSL.generate_nlsys_jacobian(sys)
33+
jac = SciCompDSL.calculate_jacobian(sys)
3434
@test jac[1,1] == σ*-1
3535
@test jac[1,2] == σ
3636
@test jac[1,3] == 0

test/system_construction.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,5 @@ eqs = [a ~ y-x,
104104
0 ~ x*y - β*z]
105105
ns = NonlinearSystem(eqs,[x,y,z],[σ,ρ,β])
106106
nlsys_func = SciCompDSL.generate_nlsys_function(ns)
107+
jac = SciCompDSL.calculate_jacobian(ns)
107108
jac = SciCompDSL.generate_nlsys_jacobian(ns)

0 commit comments

Comments
 (0)