Skip to content

Commit 7d7df8d

Browse files
authored
Merge pull request #1205 from SciML/myb/opt
Optimization and support DiscreteUpdate operator
2 parents 6980b96 + 9c79d12 commit 7d7df8d

File tree

4 files changed

+44
-41
lines changed

4 files changed

+44
-41
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ Setfield = "0.7"
7373
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0"
7474
StaticArrays = "0.10, 0.11, 0.12, 1.0"
7575
SymbolicUtils = "0.12, 0.13"
76-
Symbolics = "3.0"
76+
Symbolics = "3.1"
7777
UnPack = "0.1, 1.0"
7878
Unitful = "1.1"
7979
julia = "1.2"

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -111,44 +111,45 @@ function generate_function(
111111
end
112112
end
113113

114-
@inline function allequal(x)
115-
length(x) < 2 && return true
116-
e1 = first(x)
117-
i = 2
118-
@inbounds for i=2:length(x)
119-
x[i] == e1 || return false
120-
end
121-
return true
122-
end
123-
124-
function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = parameters(sys);
125-
kwargs...)
114+
function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
126115
eqs = equations(sys)
127116
foreach(check_difference_variables, eqs)
128117

129-
rhss = [
130-
begin
131-
ind = findfirst(eq -> isdifference(eq.lhs) && isequal(arguments(eq.lhs)[1], s), eqs)
132-
ind === nothing ? 0 : eqs[ind].rhs
133-
end
134-
for s in dvs ]
135-
118+
var2eq = Dict(arguments(eq.lhs)[1] => eq for eq in eqs if isdifference(eq.lhs))
119+
136120
u = map(x->time_varying_as_func(value(x), sys), dvs)
137121
p = map(x->time_varying_as_func(value(x), sys), ps)
138122
t = get_iv(sys)
139123

140-
f_oop, f_iip = build_function(rhss, u, p, t; kwargs...)
141-
142-
f = @RuntimeGeneratedFunction(@__MODULE__, f_oop)
124+
body = map(dvs) do v
125+
eq = get(var2eq, v, nothing)
126+
eq === nothing && return v
127+
d = operation(eq.lhs)
128+
d.update ? eq.rhs : eq.rhs + v
129+
end
143130

144-
function cb_affect!(int)
145-
int.u += f(int.u, int.p, int.t)
131+
pre = get_postprocess_fbody(sys)
132+
f_oop, f_iip = build_function(body, u, p, t; expression=Val{false}, postprocess_fbody=pre, kwargs...)
133+
134+
cb_affect! = let f_oop=f_oop, f_iip=f_iip
135+
function cb_affect!(integ)
136+
if DiffEqBase.isinplace(integ.sol.prob)
137+
tmp, = DiffEqBase.get_tmp_cache(integ)
138+
f_iip(tmp, integ.u, integ.p, integ.t) # aliasing `integ.u` would be bad.
139+
copyto!(integ.u, tmp)
140+
else
141+
integ.u = f_oop(integ.u, integ.p, integ.t)
142+
end
143+
return nothing
144+
end
146145
end
147146

148-
dts = [ operation(eq.lhs).dt for eq in eqs if isdifferenceeq(eq)]
149-
allequal(dts) || error("All difference variables should have same time steps.")
147+
getdt(eq) = operation(eq.lhs).dt
148+
deqs = values(var2eq)
149+
dt = getdt(first(deqs))
150+
all(dt == getdt(eq) for eq in deqs) || error("All difference variables should have same time steps.")
150151

151-
PeriodicCallback(cb_affect!, first(dts))
152+
PeriodicCallback(cb_affect!, first(dt))
152153
end
153154

154155
function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
@@ -578,12 +579,11 @@ symbolically calculating numerical enhancements.
578579
function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
579580
parammap=DiffEqBase.NullParameters();kwargs...) where iip
580581
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; kwargs...)
581-
if any(isdifferenceeq.(equations(sys)))
582-
ODEProblem{iip}(f,u0,tspan,p;difference_cb=generate_difference_cb(sys),kwargs...)
582+
if any(isdifferenceeq, equations(sys))
583+
ODEProblem{iip}(f,u0,tspan,p;difference_cb=generate_difference_cb(sys;kwargs...),kwargs...)
583584
else
584585
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
585586
end
586-
587587
end
588588

589589
"""
@@ -610,12 +610,11 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,du0map,u0map,tspan,
610610
diffvars = collect_differential_variables(sys)
611611
sts = states(sys)
612612
differential_vars = map(Base.Fix2(in, diffvars), sts)
613-
if any(isdifferenceeq.(equations(sys)))
614-
DAEProblem{iip}(f,du0,u0,tspan,p;difference_cb=generate_difference_cb(sys),differential_vars=differential_vars,kwargs...)
613+
if any(isdifferenceeq, equations(sys))
614+
DAEProblem{iip}(f,du0,u0,tspan,p;difference_cb=generate_difference_cb(sys; kwargs...),differential_vars=differential_vars,kwargs...)
615615
else
616-
DAEProblem{iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,kwargs...)
616+
DAEProblem{iip}(f,du0,u0,tspan,p;differential_vars=differential_vars,kwargs...)
617617
end
618-
619618
end
620619

621620
"""

src/systems/diffeqs/odesystem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ function build_explicit_observed_function(
239239
# FIXME: this is a rather rough estimate of dependencies.
240240
maxidx = 0
241241
for (i, s) in enumerate(syms)
242-
idx = observed_idx[s]
242+
idx = get(observed_idx, s, nothing)
243+
idx === nothing && throw(ArgumentError("$s is not an observed variable."))
243244
idx > maxidx && (maxidx = idx)
244245
output[i] = obs[idx].rhs
245246
end

test/odesystem.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,13 @@ eqs = [
404404
@parameters t a b c d
405405
@variables x(t) y(t)
406406
δ = Differential(t)
407-
D = Difference(t; dt=0.1)
407+
Δ = Difference(t; dt=0.1)
408+
U = DiscreteUpdate(t; dt=0.1)
408409
eqs = [
409-
δ(x) ~ a*x - b*x*y,
410-
δ(y) ~ -c*y + d*x*y,
411-
D(x) ~ y
410+
δ(x) ~ a*x - b*x*y
411+
δ(y) ~ -c*y + d*x*y
412+
Δ(x) ~ y
413+
U(y) ~ x + 1
412414
]
413415
@named de = ODESystem(eqs,t,[x,y],[a,b,c,d])
414416
@test generate_difference_cb(de) isa ModelingToolkit.DiffEqCallbacks.DiscreteCallback
@@ -431,7 +433,8 @@ end
431433

432434
prob2 = ODEProblem(lotka,[1.0,1.0],(0.0,1.0),[1.5,1.0,3.0,1.0])
433435
function periodic_difference_affect!(int)
434-
int.u += [int.u[2], 0]
436+
int.u = [int.u[1] + int.u[2], int.u[1] + 1]
437+
return nothing
435438
end
436439

437440
difference_cb = ModelingToolkit.PeriodicCallback(periodic_difference_affect!, 0.1)

0 commit comments

Comments
 (0)