@@ -115,29 +115,40 @@ function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = paramete
115
115
eqs = equations (sys)
116
116
foreach (check_difference_variables, eqs)
117
117
118
- rhss = [
119
- begin
120
- ind = findfirst (eq -> isdifference (eq. lhs) && isequal (arguments (eq. lhs)[1 ], s), eqs)
121
- ind === nothing ? 0 : eqs[ind]. rhs
122
- end
123
- for s in dvs]
118
+ var2eq = Dict (arguments (eq. lhs)[1 ] => eq for eq in eqs if isdifference (eq. lhs))
124
119
125
120
u = map (x-> time_varying_as_func (value (x), sys), dvs)
126
121
p = map (x-> time_varying_as_func (value (x), sys), ps)
127
122
t = get_iv (sys)
128
123
129
- f_oop, f_iip = build_function (rhss, u, p, t; kwargs... )
124
+ body = map (dvs) do v
125
+ eq = get (var2eq, v, nothing )
126
+ eq === nothing && return v
127
+ d = operation (eq. lhs)
128
+ d. update ? eq. rhs : eq. rhs + v
129
+ end
130
130
131
- f = @RuntimeGeneratedFunction ( @__MODULE__ , f_oop )
131
+ f_oop, f_iip = build_function (body, u, p, t; expression = Val{ false }, kwargs ... )
132
132
133
- function cb_affect! (int)
134
- int. u += f (int. u, int. p, int. t)
133
+ cb_affect! = let f_oop= f_oop, f_iip= f_iip
134
+ function cb_affect! (integ)
135
+ if DiffEqBase. isinplace (integ. sol. prob)
136
+ tmp, = DiffEqBase. get_tmp_cache (integ)
137
+ f_iip (tmp, integ. u, integ. p, integ. t) # aliasing `integ.u` would be bad.
138
+ copyto! (integ. u, tmp)
139
+ else
140
+ integ. u = f_oop (integ. u, integ. p, integ. t)
141
+ end
142
+ return nothing
143
+ end
135
144
end
136
145
137
- dts = [operation (eq. lhs). dt for eq in eqs if isdifferenceeq (eq)]
138
- all (dt == dts[1 ] for dt in dts) || error (" All difference variables should have same time steps." )
146
+ getdt (eq) = operation (eq. lhs). dt
147
+ deqs = values (var2eq)
148
+ dt = getdt (first (deqs))
149
+ all (dt == getdt (eq) for eq in deqs) || error (" All difference variables should have same time steps." )
139
150
140
- PeriodicCallback (cb_affect!, first (dts ))
151
+ PeriodicCallback (cb_affect!, first (dt ))
141
152
end
142
153
143
154
function time_varying_as_func (x, sys:: AbstractTimeDependentSystem )
0 commit comments