Skip to content

Commit 66df11f

Browse files
committed
Merge branch 'master' into myb/arrayvar
2 parents d12cdf3 + 549885d commit 66df11f

File tree

10 files changed

+239
-108
lines changed

10 files changed

+239
-108
lines changed

.github/workflows/Downstream.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ jobs:
2121
- {user: SciML, repo: CellMLToolkit.jl, group: All}
2222
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
2323
- {user: SciML, repo: DataDrivenDiffEq.jl, group: Standard}
24-
24+
- {user: SciML, repo: StructuralIdentifiability.jl, group: All}
25+
2526
steps:
2627
- uses: actions/checkout@v2
2728
- uses: julia-actions/setup-julia@v1

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "ModelingToolkit"
22
uuid = "961ee093-0014-501f-94e3-6117800e7a78"
33
authors = ["Chris Rackauckas <[email protected]>"]
4-
version = "5.23.0"
4+
version = "5.24.0"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1010
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
1111
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
12+
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
1213
DiffEqJump = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12"
1314
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
1415
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -48,6 +49,7 @@ ArrayInterface = "2.8, 3.0"
4849
ConstructionBase = "1"
4950
DataStructures = "0.17, 0.18"
5051
DiffEqBase = "6.54.0"
52+
DiffEqCallbacks = "2.16"
5153
DiffEqJump = "6.7.5"
5254
DiffRules = "0.1, 1.0"
5355
Distributions = "0.23, 0.24, 0.25"

src/ModelingToolkit.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using DataStructures
1717
using SpecialFunctions, NaNMath
1818
using RuntimeGeneratedFunctions
1919
using Base.Threads
20+
using DiffEqCallbacks
2021
import MacroTools: splitdef, combinedef, postwalk, striplines
2122
import Libdl
2223
using DocStringExtensions
@@ -182,6 +183,7 @@ export calculate_hessian, generate_hessian
182183
export calculate_massmatrix, generate_diffusion_function
183184
export stochastic_integral_transform
184185
export initialize_system_structure
186+
export generate_difference_cb
185187

186188
export BipartiteGraph, equation_dependencies, variable_dependencies
187189
export eqeq_dependencies, varvar_dependencies

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function generate_function(
9090
#obsvars = map(eq->eq.lhs, observed(sys))
9191
#fulldvs = [dvs; obsvars]
9292

93-
eqs = equations(sys)
93+
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
9494
foreach(check_derivative_variables, eqs)
9595
# substitute x(t) by just x
9696
rhss = implicit_dae ? [_iszero(eq.lhs) ? eq.rhs : eq.rhs - eq.lhs for eq in eqs] :
@@ -109,6 +109,46 @@ function generate_function(
109109
end
110110
end
111111

112+
@inline function allequal(x)
113+
length(x) < 2 && return true
114+
e1 = first(x)
115+
i = 2
116+
@inbounds for i=2:length(x)
117+
x[i] == e1 || return false
118+
end
119+
return true
120+
end
121+
122+
function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = parameters(sys);
123+
kwargs...)
124+
eqs = equations(sys)
125+
foreach(check_difference_variables, eqs)
126+
127+
rhss = [
128+
begin
129+
ind = findfirst(eq -> isdifference(eq.lhs) && isequal(arguments(eq.lhs)[1], s), eqs)
130+
ind === nothing ? 0 : eqs[ind].rhs
131+
end
132+
for s in dvs ]
133+
134+
u = map(x->time_varying_as_func(value(x), sys), dvs)
135+
p = map(x->time_varying_as_func(value(x), sys), ps)
136+
t = get_iv(sys)
137+
138+
f_oop, f_iip = build_function(rhss, u, p, t; kwargs...)
139+
140+
f = @RuntimeGeneratedFunction(@__MODULE__, f_oop)
141+
142+
function cb_affect!(int)
143+
int.u += f(int.u, int.p, int.t)
144+
end
145+
146+
dts = [ operation(eq.lhs).dt for eq in eqs if isdifferenceeq(eq)]
147+
allequal(dts) || error("All difference variables should have same time steps.")
148+
149+
PeriodicCallback(cb_affect!, first(dts))
150+
end
151+
112152
function time_varying_as_func(x, sys)
113153
# if something is not x(t) (the current state)
114154
# but is `x(t-1)` or something like that, pass in `x` as a callable function rather
@@ -124,7 +164,7 @@ function time_varying_as_func(x, sys)
124164
end
125165

126166
function calculate_massmatrix(sys::AbstractODESystem; simplify=false)
127-
eqs = equations(sys)
167+
eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
128168
dvs = states(sys)
129169
M = zeros(length(eqs),length(eqs))
130170
state2idx = Dict(s => i for (i, s) in enumerate(dvs))
@@ -536,7 +576,12 @@ symbolically calculating numerical enhancements.
536576
function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
537577
parammap=DiffEqBase.NullParameters();kwargs...) where iip
538578
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; kwargs...)
539-
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
579+
if any(isdifferenceeq.(equations(sys)))
580+
ODEProblem{iip}(f,u0,tspan,p;difference_cb=generate_difference_cb(sys),kwargs...)
581+
else
582+
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
583+
end
584+
540585
end
541586

542587
"""
@@ -563,7 +608,12 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,du0map,u0map,tspan,
563608
diffvars = collect_differential_variables(sys)
564609
sts = states(sys)
565610
differential_vars = map(Base.Fix2(in, diffvars), sts)
566-
DAEProblem{iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,kwargs...)
611+
if any(isdifferenceeq.(equations(sys)))
612+
DAEProblem{iip}(f,du0,u0,tspan,p;difference_cb=generate_difference_cb(sys),differential_vars=differential_vars,kwargs...)
613+
else
614+
DAEProblem{iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,kwargs...)
615+
end
616+
567617
end
568618

569619
"""
@@ -713,6 +763,3 @@ end
713763
function SteadyStateProblemExpr(sys::AbstractODESystem, args...; kwargs...)
714764
SteadyStateProblemExpr{true}(sys, args...; kwargs...)
715765
end
716-
717-
isdifferential(expr) = istree(expr) && operation(expr) isa Differential
718-
isdiffeq(eq) = isdifferential(eq.lhs)

src/systems/diffeqs/odesystem.jl

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -137,42 +137,6 @@ function ODESystem(
137137
ODESystem(deqs, iv′, dvs′, ps′, var_to_name, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
138138
end
139139

140-
vars(x::Sym) = Set([x])
141-
vars(exprs::Symbolic) = vars([exprs])
142-
vars(exprs) = foldl(vars!, exprs; init = Set())
143-
vars!(vars, eq::Equation) = (vars!(vars, eq.lhs); vars!(vars, eq.rhs); vars)
144-
function vars!(vars, O)
145-
if isa(O, Sym)
146-
return push!(vars, O)
147-
end
148-
!istree(O) && return vars
149-
150-
operation(O) isa Differential && return push!(vars, O)
151-
152-
if operation(O) === (getindex) &&
153-
first(arguments(O)) isa Symbolic
154-
155-
return push!(vars, O)
156-
end
157-
158-
symtype(operation(O)) <: FnType && push!(vars, O)
159-
for arg in arguments(O)
160-
vars!(vars, arg)
161-
end
162-
163-
return vars
164-
end
165-
166-
find_derivatives!(vars, expr::Equation, f=identity) = (find_derivatives!(vars, expr.lhs, f); find_derivatives!(vars, expr.rhs, f); vars)
167-
function find_derivatives!(vars, expr, f)
168-
!istree(O) && return vars
169-
operation(O) isa Differential && push!(vars, f(O))
170-
for arg in arguments(O)
171-
vars!(vars, arg)
172-
end
173-
return vars
174-
end
175-
176140
function ODESystem(eqs, iv=nothing; kwargs...)
177141
eqs = collect(eqs)
178142
# NOTE: this assumes that the order of algebric equations doesn't matter
@@ -211,30 +175,6 @@ function ODESystem(eqs, iv=nothing; kwargs...)
211175
return ODESystem(append!(diffeq, algeeq), iv, vcat(collect(diffvars), collect(algevars)), ps; kwargs...)
212176
end
213177

214-
function collect_vars!(states, parameters, expr, iv)
215-
if expr isa Sym
216-
collect_var!(states, parameters, expr, iv)
217-
else
218-
for var in vars(expr)
219-
if istree(var) && operation(var) isa Differential
220-
var, _ = var_from_nested_derivative(var)
221-
end
222-
collect_var!(states, parameters, var, iv)
223-
end
224-
end
225-
return nothing
226-
end
227-
228-
function collect_var!(states, parameters, var, iv)
229-
isequal(var, iv) && return nothing
230-
if isparameter(var) || (istree(var) && isparameter(operation(var)))
231-
push!(parameters, var)
232-
else
233-
push!(states, var)
234-
end
235-
return nothing
236-
end
237-
238178
# NOTE: equality does not check cached Jacobian
239179
function Base.:(==)(sys1::ODESystem, sys2::ODESystem)
240180
iv1 = independent_variable(sys1)
@@ -322,21 +262,6 @@ function _eq_unordered(a, b)
322262
return true
323263
end
324264

325-
function collect_differential_variables(sys::ODESystem)
326-
eqs = equations(sys)
327-
vars = Set()
328-
diffvars = Set()
329-
for eq in eqs
330-
vars!(vars, eq)
331-
for v in vars
332-
isdifferential(v) || continue
333-
push!(diffvars, arguments(v)[1])
334-
end
335-
empty!(vars)
336-
end
337-
return diffvars
338-
end
339-
340265
# We have a stand-alone function to convert a `NonlinearSystem` or `ODESystem`
341266
# to an `ODESystem` to connect systems, and we later can reply on
342267
# `structural_simplify` to convert `ODESystem`s to `NonlinearSystem`s.

src/systems/discrete_system/discrete_system.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,6 @@ function DiffEqBase.DiscreteProblem(sys::DiscreteSystem,u0map,tspan,
125125
DiscreteProblem(f,u0,tspan,p;kwargs...)
126126
end
127127

128-
isdifference(expr) = istree(expr) && operation(expr) isa Difference
129-
isdifferenceeq(eq) = isdifference(eq.lhs)
130-
131128
check_difference_variables(eq) = check_operator_variables(eq, Difference)
132129

133130
function generate_function(

src/systems/reaction/reactionsystem.jl

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -149,44 +149,38 @@ struct ReactionSystem <: AbstractSystem
149149
name::Symbol
150150
"""systems: The internal systems"""
151151
systems::Vector
152+
"""
153+
defaults: The default values to use when initial conditions and/or
154+
parameters are not supplied in `ODEProblem`.
155+
"""
156+
defaults::Dict
152157

153-
function ReactionSystem(eqs, iv, states, ps, observed, name, systems)
158+
function ReactionSystem(eqs, iv, states, ps, observed, name, systems, defaults)
154159
iv′ = value(iv)
155160
states′ = value.(states)
156161
ps′ = value.(ps)
157162
check_variables(states′, iv′)
158163
check_parameters(ps′, iv′)
159-
new(collect(eqs), iv′, states′, ps′, observed, name, systems)
164+
new(collect(eqs), iv′, states′, ps′, observed, name, systems, defaults)
160165
end
161166
end
162167

163168
function ReactionSystem(eqs, iv, species, params;
164169
observed = [],
165170
systems = [],
166-
name = gensym(:ReactionSystem))
171+
name = gensym(:ReactionSystem),
172+
default_u0=Dict(),
173+
default_p=Dict(),
174+
defaults=_merge(Dict(default_u0), Dict(default_p)))
167175

168176
#isempty(species) && error("ReactionSystems require at least one species.")
169-
ReactionSystem(eqs, iv, species, params, observed, name, systems)
177+
ReactionSystem(eqs, iv, species, params, observed, name, systems, defaults)
170178
end
171179

172180
function ReactionSystem(iv; kwargs...)
173181
ReactionSystem(Reaction[], iv, [], []; kwargs...)
174182
end
175183

176-
function equations(sys::ModelingToolkit.ReactionSystem)
177-
eqs = get_eqs(sys)
178-
systems = get_systems(sys)
179-
if isempty(systems)
180-
return eqs
181-
else
182-
eqs = [eqs;
183-
reduce(vcat,
184-
namespace_equations.(get_systems(sys));
185-
init=[])]
186-
return eqs
187-
end
188-
end
189-
190184
"""
191185
oderatelaw(rx; combinatoric_ratelaw=true)
192186
@@ -419,7 +413,7 @@ function Base.convert(::Type{<:ODESystem}, rs::ReactionSystem;
419413
name=nameof(rs), combinatoric_ratelaws=true, include_zero_odes=true, kwargs...)
420414
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, include_zero_odes=include_zero_odes)
421415
systems = map(sys -> (sys isa ODESystem) ? sys : convert(ODESystem, sys), get_systems(rs))
422-
ODESystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, kwargs...)
416+
ODESystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, defaults=get_defaults(rs), kwargs...)
423417
end
424418

425419
"""
@@ -439,7 +433,7 @@ function Base.convert(::Type{<:NonlinearSystem},rs::ReactionSystem;
439433
name=nameof(rs), combinatoric_ratelaws=true, include_zero_odes=true, kwargs...)
440434
eqs = assemble_drift(rs; combinatoric_ratelaws=combinatoric_ratelaws, as_odes=false, include_zero_odes=include_zero_odes)
441435
systems = convert.(NonlinearSystem, get_systems(rs))
442-
NonlinearSystem(eqs, get_states(rs), get_ps(rs); name=name, systems=systems, kwargs...)
436+
NonlinearSystem(eqs, get_states(rs), get_ps(rs); name=name, systems=systems, defaults=get_defaults(rs), kwargs...)
443437
end
444438

445439
"""
@@ -487,6 +481,7 @@ function Base.convert(::Type{<:SDESystem}, rs::ReactionSystem;
487481
(noise_scaling===nothing) ? get_ps(rs) : union(get_ps(rs), toparam(noise_scaling));
488482
name=name,
489483
systems=systems,
484+
defaults=get_defaults(rs),
490485
kwargs...)
491486
end
492487

@@ -507,7 +502,7 @@ function Base.convert(::Type{<:JumpSystem},rs::ReactionSystem;
507502
name=nameof(rs), combinatoric_ratelaws=true, kwargs...)
508503
eqs = assemble_jumps(rs; combinatoric_ratelaws=combinatoric_ratelaws)
509504
systems = convert.(JumpSystem, get_systems(rs))
510-
JumpSystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, kwargs...)
505+
JumpSystem(eqs, get_iv(rs), get_states(rs), get_ps(rs); name=name, systems=systems, defaults=get_defaults(rs), kwargs...)
511506
end
512507

513508

0 commit comments

Comments
 (0)