@@ -111,44 +111,45 @@ function generate_function(
111
111
end
112
112
end
113
113
114
- @inline function allequal (x)
115
- length (x) < 2 && return true
116
- e1 = first (x)
117
- i = 2
118
- @inbounds for i= 2 : length (x)
119
- x[i] == e1 || return false
120
- end
121
- return true
122
- end
123
-
124
- function generate_difference_cb (sys:: ODESystem , dvs = states (sys), ps = parameters (sys);
125
- kwargs... )
114
+ function generate_difference_cb (sys:: ODESystem , dvs = states (sys), ps = parameters (sys); kwargs... )
126
115
eqs = equations (sys)
127
116
foreach (check_difference_variables, eqs)
128
117
129
- rhss = [
130
- begin
131
- ind = findfirst (eq -> isdifference (eq. lhs) && isequal (arguments (eq. lhs)[1 ], s), eqs)
132
- ind === nothing ? 0 : eqs[ind]. rhs
133
- end
134
- for s in dvs ]
135
-
118
+ var2eq = Dict (arguments (eq. lhs)[1 ] => eq for eq in eqs if isdifference (eq. lhs))
119
+
136
120
u = map (x-> time_varying_as_func (value (x), sys), dvs)
137
121
p = map (x-> time_varying_as_func (value (x), sys), ps)
138
122
t = get_iv (sys)
139
123
140
- f_oop, f_iip = build_function (rhss, u, p, t; kwargs... )
141
-
142
- f = @RuntimeGeneratedFunction (@__MODULE__ , f_oop)
124
+ body = map (dvs) do v
125
+ eq = get (var2eq, v, nothing )
126
+ eq === nothing && return v
127
+ d = operation (eq. lhs)
128
+ d. update ? eq. rhs : eq. rhs + v
129
+ end
143
130
144
- function cb_affect! (int)
145
- int. u += f (int. u, int. p, int. t)
131
+ pre = get_postprocess_fbody (sys)
132
+ f_oop, f_iip = build_function (body, u, p, t; expression= Val{false }, postprocess_fbody= pre, kwargs... )
133
+
134
+ cb_affect! = let f_oop= f_oop, f_iip= f_iip
135
+ function cb_affect! (integ)
136
+ if DiffEqBase. isinplace (integ. sol. prob)
137
+ tmp, = DiffEqBase. get_tmp_cache (integ)
138
+ f_iip (tmp, integ. u, integ. p, integ. t) # aliasing `integ.u` would be bad.
139
+ copyto! (integ. u, tmp)
140
+ else
141
+ integ. u = f_oop (integ. u, integ. p, integ. t)
142
+ end
143
+ return nothing
144
+ end
146
145
end
147
146
148
- dts = [ operation (eq. lhs). dt for eq in eqs if isdifferenceeq (eq)]
149
- allequal (dts) || error (" All difference variables should have same time steps." )
147
+ getdt (eq) = operation (eq. lhs). dt
148
+ deqs = values (var2eq)
149
+ dt = getdt (first (deqs))
150
+ all (dt == getdt (eq) for eq in deqs) || error (" All difference variables should have same time steps." )
150
151
151
- PeriodicCallback (cb_affect!, first (dts ))
152
+ PeriodicCallback (cb_affect!, first (dt ))
152
153
end
153
154
154
155
function time_varying_as_func (x, sys:: AbstractTimeDependentSystem )
@@ -578,12 +579,11 @@ symbolically calculating numerical enhancements.
578
579
function DiffEqBase. ODEProblem {iip} (sys:: AbstractODESystem ,u0map,tspan,
579
580
parammap= DiffEqBase. NullParameters ();kwargs... ) where iip
580
581
f, u0, p = process_DEProblem (ODEFunction{iip}, sys, u0map, parammap; kwargs... )
581
- if any (isdifferenceeq .( equations (sys) ))
582
- ODEProblem {iip} (f,u0,tspan,p;difference_cb= generate_difference_cb (sys),kwargs... )
582
+ if any (isdifferenceeq, equations (sys))
583
+ ODEProblem {iip} (f,u0,tspan,p;difference_cb= generate_difference_cb (sys;kwargs ... ),kwargs... )
583
584
else
584
585
ODEProblem {iip} (f,u0,tspan,p;kwargs... )
585
586
end
586
-
587
587
end
588
588
589
589
"""
@@ -610,12 +610,11 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem,du0map,u0map,tspan,
610
610
diffvars = collect_differential_variables (sys)
611
611
sts = states (sys)
612
612
differential_vars = map (Base. Fix2 (in, diffvars), sts)
613
- if any (isdifferenceeq .( equations (sys) ))
614
- DAEProblem {iip} (f,du0,u0,tspan,p;difference_cb= generate_difference_cb (sys),differential_vars= differential_vars,kwargs... )
613
+ if any (isdifferenceeq, equations (sys))
614
+ DAEProblem {iip} (f,du0,u0,tspan,p;difference_cb= generate_difference_cb (sys; kwargs ... ),differential_vars= differential_vars,kwargs... )
615
615
else
616
- DAEProblem {iip} (f,du0,u0,tspan,p;differential_vars= differential_vars,kwargs... )
616
+ DAEProblem {iip} (f,du0,u0,tspan,p;differential_vars= differential_vars,kwargs... )
617
617
end
618
-
619
618
end
620
619
621
620
"""
0 commit comments