Skip to content

Commit 64ed9d2

Browse files
Merge pull request #2634 from oscardssmith/os/fix-Rosenbrock-allocations
fix rosenbrock allocations
2 parents c321a87 + e131cf1 commit 64ed9d2

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_caches.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ end
1212

1313
# Shampine's Low-order Rosenbrocks
1414

15-
mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
15+
mutable struct RosenbrockCache{uType, rateType, tabType, uNoUnitsType, JType, WType, TabType,
1616
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
1717
RosenbrockMutableCache
1818
u::uType
@@ -21,6 +21,8 @@ mutable struct RosenbrockCache{uType, rateType, uNoUnitsType, JType, WType, TabT
2121
du::rateType
2222
du1::rateType
2323
du2::rateType
24+
dtC::Matrix{tabType}
25+
dtd::Vector{tabType}
2426
ks::Vector{rateType}
2527
fsalfirst::rateType
2628
fsallast::rateType
@@ -761,6 +763,9 @@ function alg_cache(
761763
du1 = zero(rate_prototype)
762764
du2 = zero(rate_prototype)
763765

766+
dtC = similar(tab.C)
767+
dtd = similar(tab.d)
768+
764769
# Initialize other variables
765770
fsalfirst = zero(rate_prototype)
766771
fsallast = zero(rate_prototype)
@@ -795,7 +800,7 @@ function alg_cache(
795800

796801
# Return the cache struct with vectors
797802
RosenbrockCache(
798-
u, uprev, dense, du, du1, du2, ks, fsalfirst, fsallast,
803+
u, uprev, dense, du, du1, du2, dtC, dtd, ks, fsalfirst, fsallast,
799804
dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
800805
linsolve, jac_config, grad_config, reltol, alg,
801806
alg.step_limiter!, alg.stage_limiter!, size(tab.H, 1))

lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_perform_step.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1318,7 +1318,7 @@ end
13181318

13191319
@muladd function perform_step!(integrator, cache::RosenbrockCache, repeat_step = false)
13201320
(; t, dt, uprev, u, f, p) = integrator
1321-
(; du, du1, du2, dT, J, W, uf, tf, ks, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter!) = cache
1321+
(; du, du1, du2, dT, dtC, dtd, J, W, uf, tf, ks, linsolve_tmp, jac_config, atmp, weight, stage_limiter!, step_limiter!) = cache
13221322
(; A, C, gamma, c, d, H) = cache.tab
13231323

13241324
# Assignments
@@ -1327,8 +1327,8 @@ end
13271327
mass_matrix = integrator.f.mass_matrix
13281328

13291329
# Precalculations
1330-
dtC = C .* inv(dt)
1331-
dtd = dt .* d
1330+
@. dtC = C * inv(dt)
1331+
@. dtd = dt * d
13321332
dtgamma = dt * gamma
13331333

13341334
f(cache.fsalfirst, uprev, p, t)

0 commit comments

Comments
 (0)