Skip to content

Commit af69a3d

Browse files
authored
Merge pull request #2 from Theozeud/qprk_fsal
Add fsal property
2 parents 72f2cb5 + 29c1331 commit af69a3d

File tree

3 files changed

+32
-51
lines changed

3 files changed

+32
-51
lines changed

src/alg_utils.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,6 @@ isfsal(alg::SSPRK932) = false
7676
isfsal(alg::SSPRK54) = false
7777
isfsal(alg::SSPRK104) = false
7878

79-
isfsal(alg::QPRK98) = false
80-
8179
get_current_isfsal(alg, cache) = isfsal(alg)
8280

8381
# evaluates f(t[i])

src/caches/qprk_caches.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ struct QPRK98ConstantCache <: OrdinaryDiffEqConstantCache end
33
@cache struct QPRK98Cache{uType, rateType, uNoUnitsType, StageLimiter, StepLimiter,Thread} <: OrdinaryDiffEqMutableCache
44
u::uType
55
uprev::uType
6-
k1::rateType
6+
fsalfirst::rateType
77
k2::rateType
88
k3::rateType
99
k4::rateType
@@ -22,6 +22,7 @@ struct QPRK98ConstantCache <: OrdinaryDiffEqConstantCache end
2222
utilde::uType
2323
tmp::uType
2424
atmp::uNoUnitsType
25+
k::rateType
2526
stage_limiter!::StageLimiter
2627
step_limiter!::StepLimiter
2728
thread::Thread
@@ -48,9 +49,10 @@ function alg_cache(alg::QPRK98, u, rate_prototype, ::Type{uEltypeNoUnits}, ::Typ
4849
utilde = zero(u)
4950
tmp = zero(u)
5051
atmp = similar(u, uEltypeNoUnits)
52+
k = zero(rate_prototype)
5153
recursivefill!(atmp, false)
5254
QPRK98Cache(u, uprev, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15,
53-
k16, utilde, tmp, atmp, alg.stage_limiter!, alg.step_limiter!,
55+
k16, utilde, tmp, atmp, k, alg.stage_limiter!, alg.step_limiter!,
5456
alg.thread)
5557
end
5658

src/perform_step/qprk_perform_step.jl

Lines changed: 28 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
function initialize!(integrator, ::QPRK98ConstantCache)
2-
integrator.kshortsize = 16
2+
integrator.fsalfirst = integrator.f(integrator.uprev, integrator.p, integrator.t) # Pre-start fsal
3+
integrator.stats.nf += 1
4+
integrator.kshortsize = 2
35
integrator.k = typeof(integrator.k)(undef, integrator.kshortsize)
46

5-
@inbounds for i in eachindex(integrator.k)
6-
integrator.k[i] = zero(integrator.uprev) ./ oneunit(integrator.t)
7-
end
7+
# Avoid undefined entries if k is an array of arrays
8+
integrator.fsallast = zero(integrator.fsalfirst)
9+
integrator.k[1] = integrator.fsalfirst
10+
integrator.k[2] = integrator.fsallast
811
end
912

1013
@muladd function perform_step!(integrator,::QPRK98ConstantCache, repeat_step=false)
@@ -13,7 +16,7 @@ end
1316
T2 = constvalue(typeof(one(t)))
1417
@OnDemandTableauExtract QPRK98Tableau T T2
1518

16-
k1 = f(uprev, p, t)
19+
k1 = integrator.fsalfirst
1720
k2 = f(uprev + b21 * k1 * dt, p, t + d2 * dt)
1821
k3 = f(uprev + dt*(b31 * k1 + b32 * k2), p, t + d3 * dt)
1922
k4 = f(uprev + dt*(b41 * k1 + b43 * k3), p, t + d4 * dt)
@@ -29,7 +32,7 @@ end
2932
k14 = f(uprev + dt * (b14_1 * k1 + b14_6 * k6 + b14_7 * k7 + b14_8 * k8 + b14_9 * k9 + b14_10 * k10 + b14_11 * k11 + b14_12 * k12 + b14_13 * k13), p, t + d14 * dt)
3033
k15 = f(uprev + dt * (b15_1 * k1 + b15_6 * k6 + b15_7 * k7 + b15_8 * k8 + b15_9 * k9 + b15_10 * k10 + b15_11 * k11 + b15_12 * k12 + b15_13 * k13 + b15_14 * k14), p, t + dt)
3134
k16 = f(uprev + dt * (b16_1 * k1 + b16_6 * k6 + b16_7 * k7 + b16_8 * k8 + b16_9 * k9 + b16_10 * k10 + b16_11 * k11 + b16_12 * k12 + b16_13 * k13 + b16_14 * k14), p, t + dt)
32-
35+
integrator.stats.nf += 15
3336
u = uprev + dt * (w1 * k1 + w8 * k8 + w9 * k9 + w10 * k10 + w11 * k11 + w12 * k12 + w13 * k13 + w14 * k14 + w15 * k15 + w16 * k16)
3437

3538
if integrator.opts.adaptive
@@ -38,58 +41,32 @@ end
3841
integrator.opts.reltol, integrator.opts.internalnorm, t)
3942
integrator.EEst = integrator.opts.internalnorm(atmp, t)
4043
end
41-
42-
integrator.k[1] = k1
43-
integrator.k[2] = k2
44-
integrator.k[3] = k3
45-
integrator.k[4] = k4
46-
integrator.k[5] = k5
47-
integrator.k[6] = k6
48-
integrator.k[7] = k7
49-
integrator.k[8] = k8
50-
integrator.k[9] = k9
51-
integrator.k[10] = k10
52-
integrator.k[11] = k11
53-
integrator.k[12] = k12
54-
integrator.k[13] = k13
55-
integrator.k[14] = k14
56-
integrator.k[15] = k15
57-
integrator.k[16] = k16
44+
integrator.fsallast = f(u, p, t + dt)
45+
integrator.stats.nf += 1
46+
integrator.k[1] = integrator.fsalfirst
47+
integrator.k[2] = integrator.fsallast
5848
integrator.u = u
5949
end
6050

6151

62-
6352
function initialize!(integrator, cache::QPRK98Cache)
64-
@unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16 = cache
65-
@unpack k = integrator
66-
integrator.kshortsize = 16
67-
resize!(k, integrator.kshortsize)
68-
k[1] = k1
69-
k[2] = k2
70-
k[3] = k3
71-
k[4] = k4
72-
k[5] = k5
73-
k[6] = k6
74-
k[7] = k7
75-
k[8] = k8
76-
k[9] = k9
77-
k[10] = k10
78-
k[11] = k11
79-
k[12] = k12
80-
k[13] = k13
81-
k[14] = k14
82-
k[15] = k15
83-
k[16] = k16
53+
integrator.fsalfirst = cache.fsalfirst
54+
integrator.fsallast = cache.k
55+
integrator.kshortsize = 2
56+
resize!(integrator.k, integrator.kshortsize)
57+
integrator.k[1] = integrator.fsalfirst
58+
integrator.k[2] = integrator.fsallast
59+
integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t) # Pre-start fsal
60+
integrator.stats.nf += 1
8461
end
8562

8663
@muladd function perform_step!(integrator, cache::QPRK98Cache, repeat_step=false)
8764
@unpack t, dt, uprev, u, f, p = integrator
8865
T = constvalue(recursive_unitless_bottom_eltype(u))
8966
T2 = constvalue(typeof(one(t)))
9067
@OnDemandTableauExtract QPRK98Tableau T T2
91-
@unpack k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, utilde, tmp, atmp, stage_limiter!, step_limiter!, thread = cache
92-
68+
@unpack fsalfirst, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14, k15, k16, utilde, tmp, atmp, k, stage_limiter!, step_limiter!, thread = cache
69+
k1 = fsalfirst
9370
f(k1, uprev, p, t)
9471
@.. broadcast=false thread=thread tmp = uprev + dt * b21 * k1
9572
stage_limiter!(tmp, integrator, p, t + d2 * dt)
@@ -137,10 +114,12 @@ end
137114
stage_limiter!(u, integrator, p, t + dt)
138115
step_limiter!(u, integrator, p, t + dt)
139116
f(k16, tmp, p, t + dt)
140-
117+
141118
integrator.stats.nf += 16
142119

143120
@.. broadcast=false thread=thread u=uprev + dt * (w1 * k1 + w8 * k8 + w9 * k9 + w10 * k10 + w11 * k11 + w12 * k12 + w13 * k13 + w14 * k14 + w15 * k15 + w16 * k16)
121+
stage_limiter!(u, integrator, p, t + dt)
122+
step_limiter!(u, integrator, p, t + dt)
144123

145124
if integrator.opts.adaptive
146125
@.. broadcast=false thread=thread utilde = dt * (ϵ1 * k1 + ϵ8 * k8 +
@@ -154,5 +133,7 @@ end
154133
thread)
155134
integrator.EEst = integrator.opts.internalnorm(atmp, t)
156135
end
136+
f(k, u, p, t + dt)
137+
integrator.stats.nf += 1
157138
return nothing
158139
end

0 commit comments

Comments
 (0)