Skip to content

Commit 7ab479c

Browse files
Merge pull request #2480 from AayushSabharwal/as/nonlinear-jacobian
fix: fix jacobian generation for NonlinearSystem
2 parents aac3d1f + 2a7f190 commit 7ab479c

File tree

7 files changed

+19
-14
lines changed

7 files changed

+19
-14
lines changed

docs/src/basics/Composition.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ N = S + I + R
204204
205205
@named seqn = ODESystem([D(S) ~ -β * S * I / N], t)
206206
@named ieqn = ODESystem([D(I) ~ β * S * I / N - γ * I], t)
207-
@named reqn = ODESystem([D(R) ~ γ * I],t )
207+
@named reqn = ODESystem([D(R) ~ γ * I], t)
208208
209209
sir = compose(
210210
ODESystem(

docs/src/examples/higher_order.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ eqs = [D(D(x)) ~ σ * (y - x),
2222
D(y) ~ x * (ρ - z) - y,
2323
D(z) ~ x * y - β * z]
2424
25-
@named sys = ODESystem(eqs,t)
25+
@named sys = ODESystem(eqs, t)
2626
```
2727

2828
Note that we could've used an alternative syntax for 2nd order, i.e.

docs/src/tutorials/bifurcation_diagram_computation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ It is also possible to use `ODESystem`s (rather than `NonlinearSystem`s) as inpu
9191
using BifurcationKit, ModelingToolkit, Plots
9292
using ModelingToolkit: t_nounits as t, D_nounits as D
9393
94-
@variables x(t) y(t)
94+
@variables x(t) y(t)
9595
@parameters μ
9696
eqs = [D(x) ~ μ * x - y - x * (x^2 + y^2),
9797
D(y) ~ x + μ * y - y * (x^2 + y^2)]

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,8 @@ function get_u0(sys, u0map, parammap = nothing; symbolic_u0 = false)
807807
end
808808
defs = mergedefaults(defs, u0map, dvs)
809809
for (k, v) in defs
810-
if Symbolics.isarraysymbolic(k)
810+
if Symbolics.isarraysymbolic(k) &&
811+
Symbolics.shape(unwrap(k)) !== Symbolics.Unknown()
811812
ks = scalarize(k)
812813
length(ks) == length(v) || error("$k has default value $v with unmatched size")
813814
for (kk, vv) in zip(ks, v)

src/systems/index_cache.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,13 @@ function IndexCache(sys::AbstractSystem)
4949
end
5050
end
5151

52-
disc_buffers = Dict{DataType, Set{BasicSymbolic}}()
53-
tunable_buffers = Dict{DataType, Set{BasicSymbolic}}()
54-
constant_buffers = Dict{DataType, Set{BasicSymbolic}}()
55-
dependent_buffers = Dict{DataType, Set{BasicSymbolic}}()
56-
nonnumeric_buffers = Dict{DataType, Set{BasicSymbolic}}()
52+
disc_buffers = Dict{Any, Set{BasicSymbolic}}()
53+
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
54+
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
55+
dependent_buffers = Dict{Any, Set{BasicSymbolic}}()
56+
nonnumeric_buffers = Dict{Any, Set{BasicSymbolic}}()
5757

58-
function insert_by_type!(buffers::Dict{DataType, Set{BasicSymbolic}}, sym)
58+
function insert_by_type!(buffers::Dict{Any, Set{BasicSymbolic}}, sym)
5959
sym = unwrap(sym)
6060
ctype = concrete_symtype(sym)
6161
buf = get!(buffers, ctype, Set{BasicSymbolic}())
@@ -101,7 +101,7 @@ function IndexCache(sys::AbstractSystem)
101101
if ctype <: Real || ctype <: AbstractArray{<:Real}
102102
if is_discrete_domain(p)
103103
disc_buffers
104-
elseif istunable(p, true) && size(p) !== Symbolics.Unknown()
104+
elseif istunable(p, true) && Symbolics.shape(p) !== Symbolics.Unknown()
105105
tunable_buffers
106106
else
107107
constant_buffers
@@ -113,7 +113,7 @@ function IndexCache(sys::AbstractSystem)
113113
)
114114
end
115115

116-
function get_buffer_sizes_and_idxs(buffers::Dict{DataType, Set{BasicSymbolic}})
116+
function get_buffer_sizes_and_idxs(buffers::Dict{Any, Set{BasicSymbolic}})
117117
idxs = IndexMap()
118118
buffer_sizes = BufferTemplate[]
119119
for (i, (T, buf)) in enumerate(buffers)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,10 @@ function generate_jacobian(
175175
sys::NonlinearSystem, vs = unknowns(sys), ps = full_parameters(sys);
176176
sparse = false, simplify = false, kwargs...)
177177
jac = calculate_jacobian(sys, sparse = sparse, simplify = simplify)
178-
pre = get_preprocess_constants(jac)
178+
pre, sol_states = get_substitutions_and_solved_unknowns(sys)
179179
p = reorder_parameters(sys, ps)
180-
return build_function(jac, vs, p...; postprocess_fbody = pre, kwargs...)
180+
return build_function(
181+
jac, vs, p...; postprocess_fbody = pre, states = sol_states, kwargs...)
181182
end
182183

183184
function calculate_hessian(sys::NonlinearSystem; sparse = false, simplify = false)

test/nonlinearsystem.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ prob = NonlinearProblem(ns, ones(3), [σ => 1.0, ρ => 1.0, β => 1.0])
8787
sol = solve(prob, NewtonRaphson())
8888
@test sol.u[1] sol.u[2]
8989

90+
prob = NonlinearProblem(ns, ones(3), [σ => 1.0, ρ => 1.0, β => 1.0], jac = true)
91+
@test_nowarn solve(prob, NewtonRaphson())
92+
9093
@test_throws ArgumentError NonlinearProblem(ns, ones(4), [σ => 1.0, ρ => 1.0, β => 1.0])
9194

9295
@variables u F s a

0 commit comments

Comments
 (0)