3333 mass_matrix = integrator. f. mass_matrix
3434
3535 # Precalculations
36- γ = dt * d
36+ dtγ = dt * d
37+ neginvdtγ = - inv (dtγ)
3738 dto2 = dt / 2
3839 dto6 = dt / 6
3940
4243 OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
4344 end
4445
45- calc_rosenbrock_differentiation! (integrator, cache, γ, γ , repeat_step, false )
46+ calc_rosenbrock_differentiation! (integrator, cache, dtγ, dtγ , repeat_step, true )
4647
4748 calculate_residuals! (weight, fill! (weight, one (eltype (u))), uprev, uprev,
4849 integrator. opts. abstol, integrator. opts. reltol,
5253 linres = dolinsolve (
5354 integrator, cache. linsolve; A = nothing , b = _vec (linsolve_tmp),
5455 du = integrator. fsalfirst, u = u, p = p, t = t, weight = weight,
55- solverdata = (; gamma = γ ))
56+ solverdata = (; gamma = dtγ ))
5657 else
5758 linres = dolinsolve (integrator, cache. linsolve; A = W, b = _vec (linsolve_tmp),
5859 du = integrator. fsalfirst, u = u, p = p, t = t, weight = weight,
59- solverdata = (; gamma = γ ))
60+ solverdata = (; gamma = dtγ ))
6061 end
6162
6263 vecu = _vec (linres. u)
6364 veck₁ = _vec (k₁)
6465
65- @. . broadcast = false veck₁= - vecu
66+ @. . veck₁ = vecu * neginvdtγ
6667 integrator. stats. nsolve += 1
6768
68- @. . broadcast = false u= uprev + dto2 * k₁
69+ @. . u= uprev + dto2 * k₁
6970 stage_limiter! (u, integrator, p, t + dto2)
7071 f (f₁, u, p, t + dto2)
7172 OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
7677 mul! (_vec (tmp), mass_matrix, _vec (k₁))
7778 end
7879
79- @. . broadcast = false linsolve_tmp= f₁ - tmp
80+ @. . linsolve_tmp = f₁ - tmp
8081
8182 linres = dolinsolve (integrator, linres. cache; b = _vec (linsolve_tmp))
8283 vecu = _vec (linres. u)
83- veck2 = _vec (k₂)
84+ veck₂ = _vec (k₂)
8485
85- @. . broadcast = false veck2 = - vecu
86+ @. . veck₂ = vecu * neginvdtγ + veck₁
8687 integrator. stats. nsolve += 1
8788
88- @. . broadcast= false k₂+= k₁
89- @. . broadcast= false u= uprev + dt * k₂
89+ @. . u = uprev + dt * k₂
9090 stage_limiter! (u, integrator, p, t + dt)
9191 step_limiter! (u, integrator, p, t + dt)
9292
107107 linres = dolinsolve (integrator, linres. cache; b = _vec (linsolve_tmp))
108108 vecu = _vec (linres. u)
109109 veck3 = _vec (k₃)
110- @. . broadcast = false veck3= - vecu
110+ @. . veck3 = vecu * neginvdtγ
111111
112112 integrator. stats. nsolve += 1
113113
127127
128128 if mass_matrix != = I
129129 algvar = reshape (cache. algebraic_vars, size (u))
130- @. . broadcast = false atmp = ifelse (algvar, fsallast, false ) /
131- integrator . opts . abstol
130+ invatol = inv (integrator . opts . abstol)
131+ @. . atmp = ifelse (algvar, fsallast, false ) * invatol
132132 integrator. EEst += integrator. opts. internalnorm (atmp, t)
133133 end
134134 end
145145 mass_matrix = integrator. f. mass_matrix
146146
147147 # Precalculations
148- γ = dt * d
148+ dtγ = dt * d
149+ neginvdtγ = - inv (dtγ)
149150 dto2 = dt / 2
150151 dto6 = dt / 6
151152
154155 OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
155156 end
156157
157- calc_rosenbrock_differentiation! (integrator, cache, γ, γ , repeat_step, false )
158+ calc_rosenbrock_differentiation! (integrator, cache, dtγ, dtγ , repeat_step, true )
158159
159160 calculate_residuals! (weight, fill! (weight, one (eltype (u))), uprev, uprev,
160161 integrator. opts. abstol, integrator. opts. reltol,
@@ -164,17 +165,17 @@ end
164165 linres = dolinsolve (
165166 integrator, cache. linsolve; A = nothing , b = _vec (linsolve_tmp),
166167 du = integrator. fsalfirst, u = u, p = p, t = t, weight = weight,
167- solverdata = (; gamma = γ ))
168+ solverdata = (; gamma = dtγ ))
168169 else
169170 linres = dolinsolve (integrator, cache. linsolve; A = W, b = _vec (linsolve_tmp),
170171 du = integrator. fsalfirst, u = u, p = p, t = t, weight = weight,
171- solverdata = (; gamma = γ ))
172+ solverdata = (; gamma = dtγ ))
172173 end
173174
174175 vecu = _vec (linres. u)
175176 veck₁ = _vec (k₁)
176177
177- @. . broadcast = false veck₁= - vecu
178+ @. . veck₁ = vecu * neginvdtγ
178179 integrator. stats. nsolve += 1
179180
180181 @. . broadcast= false u= uprev + dto2 * k₁
@@ -192,13 +193,12 @@ end
192193
193194 linres = dolinsolve (integrator, linres. cache; b = _vec (linsolve_tmp))
194195 vecu = _vec (linres. u)
195- veck2 = _vec (k₂)
196+ veck₂ = _vec (k₂)
196197
197- @. . broadcast = false veck2 = - vecu
198+ @. . veck₂ = vecu * neginvdtγ + veck₁
198199 integrator. stats. nsolve += 1
199200
200- @. . broadcast= false k₂+= k₁
201- @. . broadcast= false tmp= uprev + dt * k₂
201+ @. . tmp = uprev + dt * k₂
202202 stage_limiter! (u, integrator, p, t + dt)
203203 f (fsallast, tmp, p, t + dt)
204204 OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
216216 vecu = _vec (linres. u)
217217 veck3 = _vec (k₃)
218218
219- @. . broadcast = false veck3= - vecu
219+ @. . veck3 = vecu * neginvdtγ
220220 integrator. stats. nsolve += 1
221221
222222 @. . broadcast= false u= uprev + dto6 * (k₁ + 4 k₂ + k₃)
230230 integrator. EEst = integrator. opts. internalnorm (atmp, t)
231231
232232 if mass_matrix != = I
233- @. . broadcast = false atmp = ifelse (cache . algebraic_vars, fsallast, false ) /
234- integrator . opts . abstol
233+ invatol = inv (integrator . opts . abstol)
234+ @. . atmp = ifelse (cache . algebraic_vars, fsallast, false ) * invatol
235235 integrator. EEst += integrator. opts. internalnorm (atmp, t)
236236 end
237237 end
244244 @unpack c₃₂, d, tf, uf = cache
245245
246246 # Precalculations
247- γ = dt * d
247+ dtγ = dt * d
248+ neginvdtγ = - inv (dtγ)
248249 dto2 = dt / 2
249250 dto6 = dt / 6
250251
@@ -258,22 +259,24 @@ end
258259 # Time derivative
259260 dT = calc_tderivative (integrator, cache)
260261
261- W = calc_W (integrator, cache, γ , repeat_step)
262+ W = calc_W (integrator, cache, dtγ , repeat_step, true )
262263 if ! issuccess_W (W)
263264 integrator. EEst = 2
264265 return nothing
265266 end
266267
267- k₁ = _reshape (W \ - _vec ((integrator. fsalfirst + γ * dT)), axes (uprev))
268+ k₁ = _reshape (W \ _vec ((integrator. fsalfirst + dtγ * dT)), axes (uprev)) * neginvdtγ
268269 integrator. stats. nsolve += 1
269- f₁ = f (uprev + dto2 * k₁, p, t + dto2)
270+ tmp = @. . uprev + dto2 * k₁
271+ f₁ = f (tmp, p, t + dto2)
270272 OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
271273
272274 if mass_matrix === I
273- k₂ = _reshape (W \ - _vec (f₁ - k₁), axes (uprev)) + k₁
275+ k₂ = _reshape (W \ _vec (f₁ - k₁), axes (uprev))
274276 else
275- k₂ = _reshape (W \ - _vec (f₁ - mass_matrix * k₁), axes (uprev)) + k₁
277+ k₂ = _reshape (W \ _vec (f₁ - mass_matrix * k₁), axes (uprev))
276278 end
279+ k₂ = @. . k₂ * neginvdtγ + k₁
277280 integrator. stats. nsolve += 1
278281 u = uprev + dt * k₂
279282
@@ -282,30 +285,28 @@ end
282285 OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
283286
284287 if mass_matrix === I
285- k₃ = _reshape (
286- W \
287- - _vec ((integrator. fsallast - c₃₂ * (k₂ - f₁) -
288- 2 * (k₁ - integrator. fsalfirst) + dt * dT)),
289- axes (uprev))
288+ linsolve_tmp = @. . (integrator. fsallast - c₃₂ * (k₂ - f₁) -
289+ 2 * (k₁ - integrator. fsalfirst) + dt * dT)
290290 else
291- linsolve_tmp = integrator . fsallast - mass_matrix * (c₃₂ * k₂ + 2 * k₁) +
292- c₃₂ * f₁ + 2 * integrator. fsalfirst + dt * dT
293- k₃ = _reshape (W \ - _vec (linsolve_tmp), axes (uprev) )
291+ linsolve_tmp = mass_matrix * (@. . c₃₂ * k₂ + 2 * k₁)
292+ linsolve_tmp = @. . ( integrator. fsallast - linsolve_tmp +
293+ c₃₂ * f₁ + 2 * integrator . fsalfirst + dt * dT )
294294 end
295+ k₃ = _reshape (W \ _vec (linsolve_tmp), axes (uprev)) * neginvdtγ
295296 integrator. stats. nsolve += 1
296297
297298 if u isa Number
298299 utilde = dto6 * f. mass_matrix[1 , 1 ] * (k₁ - 2 * k₂ + k₃)
299300 else
300- utilde = dto6 * f . mass_matrix * (k₁ - 2 * k₂ + k₃)
301+ utilde = f . mass_matrix * ( @. . dto6 * (k₁ - 2 * k₂ + k₃) )
301302 end
302303 atmp = calculate_residuals (utilde, uprev, u, integrator. opts. abstol,
303304 integrator. opts. reltol, integrator. opts. internalnorm, t)
304305 integrator. EEst = integrator. opts. internalnorm (atmp, t)
305306
306307 if mass_matrix != = I
307- atmp = @. ifelse ( ! integrator. differential_vars, integrator . fsallast, false ) ./
308- integrator. opts . abstol
308+ invatol = inv ( integrator. opts . abstol)
309+ atmp = @. ifelse (integrator . differential_vars, false , integrator. fsallast) * invatol
309310 integrator. EEst += integrator. opts. internalnorm (atmp, t)
310311 end
311312 end
321322 @unpack c₃₂, d, tf, uf = cache
322323
323324 # Precalculations
324- γ = dt * d
325+ dtγ = dt * d
326+ neginvdtγ = - inv (dtγ)
325327 dto2 = dt / 2
326328 dto6 = dt / 6
327329
@@ -335,52 +337,52 @@ end
335337 # Time derivative
336338 dT = calc_tderivative (integrator, cache)
337339
338- W = calc_W (integrator, cache, γ , repeat_step)
340+ W = calc_W (integrator, cache, dtγ , repeat_step, true )
339341 if ! issuccess_W (W)
340342 integrator. EEst = 2
341343 return nothing
342344 end
343345
344- k₁ = _reshape (W \ - _vec ((integrator. fsalfirst + γ * dT)), axes (uprev))
346+ k₁ = _reshape (W \ - _vec ((integrator. fsalfirst + dtγ * dT)), axes (uprev))/ dtγ
345347 integrator. stats. nsolve += 1
346- f₁ = f (uprev + dto2 * k₁, p, t + dto2)
348+ tmp = @. . uprev + dto2 * k₁
349+ f₁ = f (tmp, p, t + dto2)
347350 OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
348351
349352 if mass_matrix === I
350- k₂ = _reshape (W \ - _vec (f₁ - k₁), axes (uprev)) + k₁
353+ k₂ = _reshape (W \ _vec (f₁ - k₁), axes (uprev))
351354 else
352355 linsolve_tmp = f₁ - mass_matrix * k₁
353- k₂ = _reshape (W \ - _vec (linsolve_tmp), axes (uprev)) + k₁
356+ k₂ = _reshape (W \ _vec (linsolve_tmp), axes (uprev))
354357 end
358+ k₂ = @. . k₂ * neginvdtγ + k₁
355359
356360 integrator. stats. nsolve += 1
357- tmp = uprev + dt * k₂
361+ tmp = @. . uprev + dt * k₂
358362 integrator. fsallast = f (tmp, p, t + dt)
359363 OrdinaryDiffEqCore. increment_nf! (integrator. stats, 1 )
360364
361365 if mass_matrix === I
362- k₃ = _reshape (
363- W \
364- - _vec ((integrator. fsallast - c₃₂ * (k₂ - f₁) -
365- 2 (k₁ - integrator. fsalfirst) + dt * dT)),
366- axes (uprev))
366+ linsolve_tmp = @. . (integrator. fsallast - c₃₂ * (k₂ - f₁) -
367+ 2 (k₁ - integrator. fsalfirst) + dt * dT)
367368 else
368- linsolve_tmp = integrator . fsallast - mass_matrix * (c₃₂ * k₂ + 2 k₁) + c₃₂ * f₁ +
369- 2 * integrator. fsalfirst + dt * dT
370- k₃ = _reshape (W \ - _vec (linsolve_tmp), axes (uprev) )
369+ linsolve_tmp = mass_matrix * (@. . c₃₂ * k₂ + 2 * k₁)
370+ linsolve_tmp = @. . ( integrator. fsallast - linsolve_tmp +
371+ c₃₂ * f₁ + 2 * integrator . fsalfirst + dt * dT )
371372 end
373+ k₃ = _reshape (W \ _vec (linsolve_tmp), axes (uprev)) * neginvdtγ
372374 integrator. stats. nsolve += 1
373- u = uprev + dto6 * (k₁ + 4 k₂ + k₃)
375+ u = @. . uprev + dto6 * (k₁ + 4 k₂ + k₃)
374376
375377 if integrator. opts. adaptive
376- utilde = dto6 * (k₁ - 2 k₂ + k₃)
378+ utilde = @. . dto6 * (k₁ - 2 k₂ + k₃)
377379 atmp = calculate_residuals (utilde, uprev, u, integrator. opts. abstol,
378380 integrator. opts. reltol, integrator. opts. internalnorm, t)
379381 integrator. EEst = integrator. opts. internalnorm (atmp, t)
380382
381383 if mass_matrix != = I
382- atmp = @. ifelse ( ! integrator. differential_vars, integrator . fsallast, false ) ./
383- integrator. opts . abstol
384+ invatol = inv ( integrator. opts . abstol)
385+ atmp = ifelse (integrator . differential_vars, false , integrator. fsallast) .* invatol
384386 integrator. EEst += integrator. opts. internalnorm (atmp, t)
385387 end
386388 end
0 commit comments