Skip to content

Commit 1938d52

Browse files
author
oscarddssmith
committed
fix addsteps
1 parent adbb347 commit 1938d52

File tree

2 files changed

+28
-22
lines changed

2 files changed

+28
-22
lines changed

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ end
6868
@.. veck₁ = vecu * neginvdtγ
6969
integrator.stats.nsolve += 1
7070

71-
@.. broadcast=false u=uprev + dto2 * k₁
71+
@.. u=uprev + dto2 * k₁
7272
stage_limiter!(u, integrator, p, t + dto2)
7373
f(f₁, u, p, t + dto2)
7474
integrator.stats.nf += 1
@@ -79,17 +79,16 @@ end
7979
mul!(_vec(tmp), mass_matrix, _vec(k₁))
8080
end
8181

82-
@.. broadcast=false linsolve_tmp=f₁ - tmp
82+
@.. linsolve_tmp = f₁ - tmp
8383

8484
linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
8585
vecu = _vec(linres.u)
86-
veck2 = _vec(k₂)
86+
veck₂ = _vec(k₂)
8787

88-
@.. veck2 = vecu * neginvdtγ
88+
@.. veck₂ = vecu * neginvdtγ + veck₁
8989
integrator.stats.nsolve += 1
9090

91-
@.. broadcast=false k₂+=k₁
92-
@.. broadcast=false u=uprev + dt * k₂
91+
@.. u = uprev + dt * k₂
9392
stage_limiter!(u, integrator, p, t + dt)
9493
step_limiter!(u, integrator, p, t + dt)
9594

@@ -196,13 +195,12 @@ end
196195

197196
linres = dolinsolve(integrator, linres.cache; b = _vec(linsolve_tmp))
198197
vecu = _vec(linres.u)
199-
veck2 = _vec(k₂)
198+
veck₂ = _vec(k₂)
200199

201-
@.. veck2 = vecu * neginvdtγ
200+
@.. veck₂ = vecu * neginvdtγ + veck₁
202201
integrator.stats.nsolve += 1
203202

204-
@.. broadcast=false k₂+=k₁
205-
@.. broadcast=false tmp=uprev + dt * k₂
203+
@.. tmp = uprev + dt * k₂
206204
stage_limiter!(u, integrator, p, t + dt)
207205
f(fsallast, tmp, p, t + dt)
208206
integrator.stats.nf += 1

lib/OrdinaryDiffEqRosenbrock/src/stiff_addsteps.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
66
if length(k) < 2 || always_calc_begin
77
@unpack tf, uf, d = cache
88
dtγ = dt * d
9+
neginvdtγ = -inv(dtγ)
910
dto2 = dt / 2
1011
tf.u = uprev
1112
if cache.autodiff isa AutoForwardDiff
@@ -17,16 +18,25 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
1718
mass_matrix = f.mass_matrix
1819
if uprev isa Number
1920
J = ForwardDiff.derivative(uf, uprev)
20-
W = 1 - dtγ * J
21+
W = neginvdtγ .+ J
2122
else
2223
J = ForwardDiff.jacobian(uf, uprev)
23-
W = mass_matrix - dtγ * J
24+
if mass_matrix isa UniformScaling
25+
W = neginvdtγ*mass_matrix + J
26+
else
27+
W = @.. neginvdtγ*mass_matrix .+ J
28+
end
2429
end
2530
f₀ = f(uprev, p, t)
26-
k₁ = W \ (@.. f₀ + dtγ * dT)
31+
k₁ = _reshape(W \ _vec((f₀ + dtγ * dT)), axes(uprev)) * neginvdtγ
2732
tmp = @.. uprev + dto2 * k₁
2833
f₁ = f(tmp, p, t + dto2)
29-
k₂ = (W \ (f₁ - k₁)) + k₁
34+
if mass_matrix === I
35+
k₂ = _reshape(W \ _vec(f₁ - k₁), axes(uprev))
36+
else
37+
k₂ = _reshape(W \ _vec(f₁ - mass_matrix * k₁), axes(uprev))
38+
end
39+
k₂ = @.. k₂ * neginvdtγ + k₁
3040
copyat_or_push!(k, 1, k₁)
3141
copyat_or_push!(k, 2, k₂)
3242
end
@@ -46,6 +56,7 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
4656
sizeu = size(u)
4757
mass_matrix = f.mass_matrix
4858
dtγ = dt * d
59+
neginvdtγ = -inv(dtγ)
4960
dto2 = dt / 2
5061

5162
@.. linsolve_tmp=@muladd fsalfirst + dtγ * dT
@@ -61,10 +72,9 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
6172

6273
vecu = _vec(linres.u)
6374
veck₁ = _vec(k₁)
75+
@.. veck₁ = vecu * neginvdtγ
6476

65-
@.. broadcast=false veck₁=-vecu
66-
67-
@.. broadcast=false tmp=uprev + dto2 * k₁
77+
@.. tmp=uprev + dto2 * k₁
6878
f(f₁, tmp, p, t + dto2)
6979

7080
if mass_matrix === I
@@ -73,16 +83,14 @@ function _ode_addsteps!(k, t, uprev, u, dt, f, p,
7383
mul!(_vec(tmp), mass_matrix, _vec(k₁))
7484
end
7585

76-
@.. broadcast=false linsolve_tmp=f₁ - tmp
86+
@.. linsolve_tmp = f₁ - tmp
7787

7888
linres = dolinsolve(cache, linres.cache; b = _vec(linsolve_tmp),
7989
reltol = cache.reltol)
8090
vecu = _vec(linres.u)
81-
veck2 = _vec(k₂)
82-
83-
@.. broadcast=false veck2=-vecu
91+
veck₂ = _vec(k₂)
8492

85-
@.. broadcast=false k₂+=k₁
93+
@.. veck₂ = vecu * neginvdtγ + veck₁
8694

8795
copyat_or_push!(k, 1, k₁)
8896
copyat_or_push!(k, 2, k₂)

0 commit comments

Comments
 (0)