@@ -6,6 +6,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
66 if length (k) < 2 || always_calc_begin
77 @unpack tf, uf, d = cache
88 dtγ = dt * d
9+ neginvdtγ = - inv (dtγ)
910 dto2 = dt / 2
1011 tf. u = uprev
1112 if cache. autodiff isa AutoForwardDiff
@@ -17,16 +18,25 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
1718 mass_matrix = f. mass_matrix
1819 if uprev isa Number
1920 J = ForwardDiff. derivative (uf, uprev)
20- W = 1 - dtγ * J
21+ W = neginvdtγ .+ J
2122 else
2223 J = ForwardDiff. jacobian (uf, uprev)
23- W = mass_matrix - dtγ * J
24+ if mass_matrix isa UniformScaling
25+ W = neginvdtγ* mass_matrix + J
26+ else
27+ W = @. . neginvdtγ* mass_matrix .+ J
28+ end
2429 end
2530 f₀ = f (uprev, p, t)
26- k₁ = W \ ( @. . f₀ + dtγ * dT)
31+ k₁ = _reshape ( W \ _vec (( f₀ + dtγ * dT)), axes (uprev)) * neginvdtγ
2732 tmp = @. . uprev + dto2 * k₁
2833 f₁ = f (tmp, p, t + dto2)
29- k₂ = (W \ (f₁ - k₁)) + k₁
34+ if mass_matrix === I
35+ k₂ = _reshape (W \ _vec (f₁ - k₁), axes (uprev))
36+ else
37+ k₂ = _reshape (W \ _vec (f₁ - mass_matrix * k₁), axes (uprev))
38+ end
39+ k₂ = @. . k₂ * neginvdtγ + k₁
3040 copyat_or_push! (k, 1 , k₁)
3141 copyat_or_push! (k, 2 , k₂)
3242 end
@@ -46,6 +56,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
4656 sizeu = size (u)
4757 mass_matrix = f. mass_matrix
4858 dtγ = dt * d
59+ neginvdtγ = - inv (dtγ)
4960 dto2 = dt / 2
5061
5162 @. . linsolve_tmp= @muladd fsalfirst + dtγ * dT
@@ -61,10 +72,9 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
6172
6273 vecu = _vec (linres. u)
6374 veck₁ = _vec (k₁)
75+ @. . veck₁ = vecu * neginvdtγ
6476
65- @. . broadcast= false veck₁= - vecu
66-
67- @. . broadcast= false tmp= uprev + dto2 * k₁
77+ @. . tmp= uprev + dto2 * k₁
6878 f (f₁, tmp, p, t + dto2)
6979
7080 if mass_matrix === I
@@ -73,16 +83,14 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
7383 mul! (_vec (tmp), mass_matrix, _vec (k₁))
7484 end
7585
76- @. . broadcast = false linsolve_tmp= f₁ - tmp
86+ @. . linsolve_tmp = f₁ - tmp
7787
7888 linres = dolinsolve (cache, linres. cache; b = _vec (linsolve_tmp),
7989 reltol = cache. reltol)
8090 vecu = _vec (linres. u)
81- veck2 = _vec (k₂)
82-
83- @. . broadcast= false veck2= - vecu
91+ veck₂ = _vec (k₂)
8492
85- @. . broadcast = false k₂ += k ₁
93+ @. . veck₂ = vecu * neginvdtγ + veck ₁
8694
8795 copyat_or_push! (k, 1 , k₁)
8896 copyat_or_push! (k, 2 , k₂)
0 commit comments