Skip to content

Commit 4b55b5b

Browse files
committed
Expanded to include other generate functions for ODESystems & added more testing.
1 parent 8e1d0c9 commit 4b55b5b

File tree

4 files changed

+59
-32
lines changed

4 files changed

+59
-32
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,23 @@ end
8383
function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
8484
simplify = false, kwargs...)
8585
tgrad = calculate_tgrad(sys, simplify = simplify)
86-
return build_function(tgrad, dvs, ps, get_iv(sys); kwargs...)
86+
pre = get_preprocess_constants(tgrad)
87+
return build_function(tgrad, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
8788
end
8889

8990
function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
9091
simplify = false, sparse = false, kwargs...)
9192
jac = calculate_jacobian(sys; simplify = simplify, sparse = sparse)
92-
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
93+
pre = get_preprocess_constants(jac)
94+
return build_function(jac, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
9395
end
9496

9597
function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys),
9698
ps = parameters(sys);
9799
simplify = false, sparse = false, kwargs...)
98100
jac = calculate_control_jacobian(sys; simplify = simplify, sparse = sparse)
99-
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
101+
pre = get_preprocess_constants(jac)
102+
return build_function(jac, dvs, ps, get_iv(sys); postprocess_fbody = pre, kwargs...)
100103
end
101104

102105
function generate_dae_jacobian(sys::AbstractODESystem, dvs = states(sys),
@@ -109,7 +112,8 @@ function generate_dae_jacobian(sys::AbstractODESystem, dvs = states(sys),
109112
dvs = states(sys)
110113
@variables ˍ₋gamma
111114
jac = ˍ₋gamma * jac_du + jac_u
112-
return build_function(jac, derivatives, dvs, ps, ˍ₋gamma, get_iv(sys); kwargs...)
115+
pre = get_preprocess_constants(jac)
116+
return build_function(jac, derivatives, dvs, ps, ˍ₋gamma, get_iv(sys); postprocess_fbody = pre, kwargs...)
113117
end
114118

115119
function generate_function(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
@@ -163,8 +167,10 @@ function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = paramete
163167
end
164168

165169
pre = get_postprocess_fbody(sys)
170+
cpre = get_preprocess_constants(body)
171+
pre2 = x -> pre(cpre(x))
166172
f_oop, f_iip = build_function(body, u, p, t; expression = Val{false},
167-
postprocess_fbody = pre, kwargs...)
173+
postprocess_fbody = pre2, kwargs...)
168174

169175
cb_affect! = let f_oop = f_oop, f_iip = f_iip
170176
function cb_affect!(integ)

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,6 @@ function ODESystem(eqs, iv = nothing; kwargs...)
238238
collect(Iterators.flatten((diffvars, algevars))), ps; kwargs...)
239239
end
240240

241-
function collect_constants(eqs) #Does this need to be different for other system types?
242-
constants = Set()
243-
for eq in eqs
244-
collect_constants!(constants, eq.lhs)
245-
collect_constants!(constants, eq.rhs)
246-
end
247-
return collect(constants)
248-
end
249241

250242
# NOTE: equality does not check cached Jacobian
251243
function Base.:(==)(sys1::ODESystem, sys2::ODESystem)

src/utils.jl

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -480,17 +480,6 @@ function collect_vars!(states, parameters, expr, iv)
480480
return nothing
481481
end
482482

483-
function collect_constants!(constants, expr)
484-
if expr isa Sym
485-
collect_constant!(constants, expr)
486-
else
487-
for var in vars(expr)
488-
collect_constant!(constants, var)
489-
end
490-
end
491-
return nothing
492-
end
493-
494483
function collect_vars_difference!(states, parameters, expr, iv)
495484
if expr isa Sym
496485
collect_var!(states, parameters, expr, iv)
@@ -515,10 +504,46 @@ function collect_var!(states, parameters, var, iv)
515504
return nothing
516505
end
517506

507+
function collect_constants(eqs::Vector{Equation}) #For get_substitutions_and_solved_states
508+
constants = []
509+
for eq in eqs
510+
collect_constants!(constants, eq.lhs)
511+
collect_constants!(constants, eq.rhs)
512+
end
513+
return constants
514+
end
515+
516+
function collect_constants(eqs::AbstractArray{T}) where T # For generate_tgrad / generate_jacobian / generate_difference_cb
517+
constants = T[]
518+
for eq in eqs
519+
collect_constants!(constants, unwrap(eq))
520+
end
521+
return constants
522+
end
523+
518524
function collect_constant!(constants, var)
519525
if isconstant(var)
520-
push!(constants,var)
526+
push!(constants, var)
521527
end
528+
return nothing
529+
end
530+
531+
function collect_constants!(constants, expr)
532+
if expr isa Sym
533+
collect_constant!(constants, expr)
534+
else
535+
for var in vars(expr)
536+
collect_constant!(constants, var)
537+
end
538+
end
539+
return nothing
540+
end
541+
542+
function get_preprocess_constants(eqs)
543+
cs = collect_constants(eqs)
544+
pre = ex -> Let(Assignment[Assignment(x, getdefault(x)) for x in cs],
545+
ex, false)
546+
return pre
522547
end
523548

524549
function get_postprocess_fbody(sys)
@@ -561,6 +586,9 @@ end
561586
function get_substitutions_and_solved_states(sys; no_postprocess = false)
562587
#Inject substitutions for constants => values
563588
cs = collect_constants([sys.eqs; sys.observed]) #ctrls? what else?
589+
if !empty_substitutions(sys)
590+
cs = [cs; collect_constants(sys.substitutions.subs)]
591+
end
564592
# Swap constants for their values
565593
cmap = map(x -> x ~ getdefault(x), cs)
566594

test/odesystem.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,14 @@ using ModelingToolkit: value
88

99
# Define some variables
1010
@parameters t σ ρ β
11+
@constants κ = 1
1112
@variables x(t) y(t) z(t)
1213
D = Differential(t)
1314

1415
# Define a differential equation
1516
eqs = [D(x) ~ σ * (y - x),
1617
D(y) ~ x *- z) - y,
17-
D(z) ~ x * y - β * z]
18+
D(z) ~ x * y - β * z * κ]
1819

1920
ModelingToolkit.toexpr.(eqs)[1]
2021
@named de = ODESystem(eqs; defaults = Dict(x => 1))
@@ -71,7 +72,7 @@ end
7172

7273
eqs = [D(x) ~ σ * (y - x),
7374
D(y) ~ x *- z) - y * t,
74-
D(z) ~ x * y - β * z]
75+
D(z) ~ x * y - β * z * κ]
7576
@named de = ODESystem(eqs)
7677
ModelingToolkit.calculate_tgrad(de)
7778

@@ -87,7 +88,7 @@ tgrad_iip(du, u, p, t)
8788
@parameters σ′(t - 1)
8889
eqs = [D(x) ~ σ′ * (y - x),
8990
D(y) ~ x *- z) - y,
90-
D(z) ~ x * y - β * z]
91+
D(z) ~ x * y - β * z * κ]
9192
@named de = ODESystem(eqs)
9293
test_diffeq_inference("global iv-varying", de, t, (x, y, z), (σ′, ρ, β))
9394

@@ -99,7 +100,7 @@ f(du, [1.0, 2.0, 3.0], [x -> x + 7, 2, 3], 5.0)
99100
@parameters σ(..)
100101
eqs = [D(x) ~ σ(t - 1) * (y - x),
101102
D(y) ~ x *- z) - y,
102-
D(z) ~ x * y - β * z]
103+
D(z) ~ x * y - β * z * κ]
103104
@named de = ODESystem(eqs)
104105
test_diffeq_inference("single internal iv-varying", de, t, (x, y, z), (σ(t - 1), ρ, β))
105106
f = eval(generate_function(de, [x, y, z], [σ, ρ, β])[2])
@@ -146,7 +147,7 @@ ODEFunction(de1, [uˍtt, xˍt, uˍt, u, x], [])(du, ones(5), nothing, 0.1)
146147
a = y - x
147148
eqs = [D(x) ~ σ * a,
148149
D(y) ~ x *- z) - y,
149-
D(z) ~ x * y - β * z]
150+
D(z) ~ x * y - β * z * κ]
150151
@named de = ODESystem(eqs)
151152
generate_function(de, [x, y, z], [σ, ρ, β])
152153
jac = calculate_jacobian(de)
@@ -201,7 +202,7 @@ D = Differential(t)
201202
# reorder the system just to be a little spicier
202203
eqs = [D(y₁) ~ -k₁ * y₁ + k₃ * y₂ * y₃,
203204
0 ~ y₁ + y₂ + y₃ - 1,
204-
D(y₂) ~ k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃]
205+
D(y₂) ~ k₁ * y₁ - k₂ * y₂^2 - k₃ * y₂ * y₃ * κ]
205206
@named sys = ODESystem(eqs, defaults = [k₁ => 100, k₂ => 3e7, y₁ => 1.0])
206207
u0 = Pair[]
207208
push!(u0, y₂ => 0.0)
@@ -222,7 +223,7 @@ for p in [prob1, prob14]
222223
@test Set(Num.(states(sys)) .=> p.u0) == Set([y₁ => 1, y₂ => 0, y₃ => 0])
223224
end
224225
prob2 = ODEProblem(sys, u0, tspan, p, jac = true)
225-
prob3 = ODEProblem(sys, u0, tspan, p, jac = true, sparse = true)
226+
prob3 = ODEProblem(sys, u0, tspan, p, jac = true, sparse = true) #SparseMatrixCSC need to handle
226227
@test prob3.f.jac_prototype isa SparseMatrixCSC
227228
prob3 = ODEProblem(sys, u0, tspan, p, jac = true, sparsity = true)
228229
@test prob3.f.sparsity isa SparseMatrixCSC

0 commit comments

Comments
 (0)