Skip to content

Commit 3fbf453

Browse files
committed
Optimize discrete callbacks
1 parent c5f03bc commit 3fbf453

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -115,29 +115,40 @@ function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = paramete
115115
eqs = equations(sys)
116116
foreach(check_difference_variables, eqs)
117117

118-
rhss = [
119-
begin
120-
ind = findfirst(eq -> isdifference(eq.lhs) && isequal(arguments(eq.lhs)[1], s), eqs)
121-
ind === nothing ? 0 : eqs[ind].rhs
122-
end
123-
for s in dvs]
118+
var2eq = Dict(arguments(eq.lhs)[1] => eq for eq in eqs if isdifference(eq.lhs))
124119

125120
u = map(x->time_varying_as_func(value(x), sys), dvs)
126121
p = map(x->time_varying_as_func(value(x), sys), ps)
127122
t = get_iv(sys)
128123

129-
f_oop, f_iip = build_function(rhss, u, p, t; kwargs...)
124+
body = map(dvs) do v
125+
eq = get(var2eq, v, nothing)
126+
eq === nothing && return v
127+
d = operation(eq.lhs)
128+
d.update ? eq.rhs : eq.rhs + v
129+
end
130130

131-
f = @RuntimeGeneratedFunction(@__MODULE__, f_oop)
131+
f_oop, f_iip = build_function(body, u, p, t; expression=Val{false}, kwargs...)
132132

133-
function cb_affect!(int)
134-
int.u += f(int.u, int.p, int.t)
133+
cb_affect! = let f_oop=f_oop, f_iip=f_iip
134+
function cb_affect!(integ)
135+
if DiffEqBase.isinplace(integ.sol.prob)
136+
tmp, = DiffEqBase.get_tmp_cache(integ)
137+
f_iip(tmp, integ.u, integ.p, integ.t) # aliasing `integ.u` would be bad.
138+
copyto!(integ.u, tmp)
139+
else
140+
integ.u = f_oop(integ.u, integ.p, integ.t)
141+
end
142+
return nothing
143+
end
135144
end
136145

137-
dts = [operation(eq.lhs).dt for eq in eqs if isdifferenceeq(eq)]
138-
all(dt == dts[1] for dt in dts) || error("All difference variables should have same time steps.")
146+
getdt(eq) = operation(eq.lhs).dt
147+
deqs = values(var2eq)
148+
dt = getdt(first(deqs))
149+
all(dt == getdt(eq) for eq in deqs) || error("All difference variables should have same time steps.")
139150

140-
PeriodicCallback(cb_affect!, first(dts))
151+
PeriodicCallback(cb_affect!, first(dt))
141152
end
142153

143154
function time_varying_as_func(x, sys::AbstractTimeDependentSystem)

0 commit comments

Comments
 (0)