@@ -111,28 +111,17 @@ function generate_function(
111
111
end
112
112
end
113
113
114
- @inline function allequal (x)
115
- length (x) < 2 && return true
116
- e1 = first (x)
117
- i = 2
118
- @inbounds for i= 2 : length (x)
119
- x[i] == e1 || return false
120
- end
121
- return true
122
- end
123
-
124
- function generate_difference_cb (sys:: ODESystem , dvs = states (sys), ps = parameters (sys);
125
- kwargs... )
114
+ function generate_difference_cb (sys:: ODESystem , dvs = states (sys), ps = parameters (sys); kwargs... )
126
115
eqs = equations (sys)
127
116
foreach (check_difference_variables, eqs)
128
117
129
- rhss = [
118
+ rhss = [
130
119
begin
131
120
ind = findfirst (eq -> isdifference (eq. lhs) && isequal (arguments (eq. lhs)[1 ], s), eqs)
132
121
ind === nothing ? 0 : eqs[ind]. rhs
133
122
end
134
- for s in dvs ]
135
-
123
+ for s in dvs]
124
+
136
125
u = map (x-> time_varying_as_func (value (x), sys), dvs)
137
126
p = map (x-> time_varying_as_func (value (x), sys), ps)
138
127
t = get_iv (sys)
@@ -141,12 +130,12 @@ function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = paramete
141
130
142
131
f = @RuntimeGeneratedFunction (@__MODULE__ , f_oop)
143
132
144
- function cb_affect! (int)
145
- int. u += f (int. u, int. p, int. t)
133
+ function cb_affect! (int)
134
+ int. u += f (int. u, int. p, int. t)
146
135
end
147
136
148
- dts = [ operation (eq. lhs). dt for eq in eqs if isdifferenceeq (eq)]
149
- allequal ( dts) || error (" All difference variables should have same time steps." )
137
+ dts = [operation (eq. lhs). dt for eq in eqs if isdifferenceeq (eq)]
138
+ all (dt == dts[ 1 ] for dt in dts) || error (" All difference variables should have same time steps." )
150
139
151
140
PeriodicCallback (cb_affect!, first (dts))
152
141
end
0 commit comments