Skip to content

Commit f9fd04e

Browse files
Introduce and propagate stage limiters for Rosenbrock methods.
1 parent aa44608 commit f9fd04e

File tree

3 files changed

+90
-40
lines changed

3 files changed

+90
-40
lines changed

src/algorithms.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3085,20 +3085,23 @@ for Alg in [
30853085
:Rodas5Pe,
30863086
:Rodas5Pr]
30873087
@eval begin
3088-
struct $Alg{CS, AD, F, P, FDT, ST, CJ, StepLimiter} <:
3088+
struct $Alg{CS, AD, F, P, FDT, ST, CJ, StepLimiter, StageLimiter} <:
30893089
OrdinaryDiffEqRosenbrockAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
30903090
linsolve::F
30913091
precs::P
30923092
step_limiter!::StepLimiter
3093+
stage_limiter!::StageLimiter
30933094
end
30943095
function $Alg(; chunk_size = Val{0}(), autodiff = Val{true}(),
30953096
standardtag = Val{true}(), concrete_jac = nothing,
30963097
diff_type = Val{:forward}, linsolve = nothing,
3097-
precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!)
3098+
precs = DEFAULT_PRECS, step_limiter! = trivial_limiter!,
3099+
stage_limiter! = trivial_limiter!)
30983100
$Alg{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
30993101
typeof(precs), diff_type, _unwrap_val(standardtag),
3100-
_unwrap_val(concrete_jac), typeof(step_limiter!)}(linsolve,
3101-
precs, step_limiter!)
3102+
_unwrap_val(concrete_jac), typeof(step_limiter!),
3103+
typeof(stage_limiter!)}(linsolve, precs, step_limiter!,
3104+
stage_limiter!)
31023105
end
31033106
end
31043107

src/caches/rosenbrock_caches.jl

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ abstract type RosenbrockMutableCache <: OrdinaryDiffEqMutableCache end
55

66
@cache mutable struct Rosenbrock23Cache{uType, rateType, uNoUnitsType, JType, WType,
77
TabType, TFType, UFType, F, JCType, GCType,
8-
RTolType, A, AV, StepLimiter} <: RosenbrockMutableCache
8+
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
99
u::uType
1010
uprev::uType
1111
k₁::rateType
@@ -33,13 +33,14 @@ abstract type RosenbrockMutableCache <: OrdinaryDiffEqMutableCache end
3333
alg::A
3434
algebraic_vars::AV
3535
step_limiter!::StepLimiter
36+
stage_limiter!::StageLimiter
3637
end
3738

3839
TruncatedStacktraces.@truncate_stacktrace Rosenbrock23Cache 1
3940

4041
@cache mutable struct Rosenbrock32Cache{uType, rateType, uNoUnitsType, JType, WType,
4142
TabType, TFType, UFType, F, JCType, GCType,
42-
RTolType, A, AV, StepLimiter} <: RosenbrockMutableCache
43+
RTolType, A, AV, StepLimiter, StageLimiter} <: RosenbrockMutableCache
4344
u::uType
4445
uprev::uType
4546
k₁::rateType
@@ -67,6 +68,7 @@ TruncatedStacktraces.@truncate_stacktrace Rosenbrock23Cache 1
6768
alg::A
6869
algebraic_vars::AV
6970
step_limiter!::StepLimiter
71+
stage_limiter!::StageLimiter
7072
end
7173

7274
function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -110,7 +112,8 @@ function alg_cache(alg::Rosenbrock23, u, rate_prototype, ::Type{uEltypeNoUnits},
110112
Rosenbrock23Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁,
111113
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
112114
linsolve_tmp,
113-
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!)
115+
linsolve, jac_config, grad_config, reltol, alg, algebraic_vars, alg.step_limiter!,
116+
alg.stage_limiter!)
114117
end
115118

116119
function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -153,7 +156,7 @@ function alg_cache(alg::Rosenbrock32, u, rate_prototype, ::Type{uEltypeNoUnits},
153156

154157
Rosenbrock32Cache(u, uprev, k₁, k₂, k₃, du1, du2, f₁, fsalfirst, fsallast, dT, J, W,
155158
tmp, atmp, weight, tab, tf, uf, linsolve_tmp, linsolve, jac_config,
156-
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!)
159+
grad_config, reltol, alg, algebraic_vars, alg.step_limiter!, alg.stage_limiter!)
157160
end
158161

159162
struct Rosenbrock23ConstantCache{T, TF, UF, JType, WType, F, AD} <:
@@ -232,7 +235,7 @@ end
232235

233236
@cache mutable struct Rosenbrock33Cache{uType, rateType, uNoUnitsType, JType, WType,
234237
TabType, TFType, UFType, F, JCType, GCType,
235-
RTolType, A, StepLimiter} <: RosenbrockMutableCache
238+
RTolType, A, StepLimiter, StageLimiter} <: RosenbrockMutableCache
236239
u::uType
237240
uprev::uType
238241
du::rateType
@@ -260,6 +263,7 @@ end
260263
reltol::RTolType
261264
alg::A
262265
step_limiter!::StepLimiter
266+
stage_limiter!::StageLimiter
263267
end
264268

265269
function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -298,7 +302,8 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
298302
Rosenbrock33Cache(u, uprev, du, du1, du2, k1, k2, k3, k4,
299303
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
300304
linsolve_tmp,
301-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
305+
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
306+
alg.stage_limiter!)
302307
end
303308

304309
function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -316,7 +321,7 @@ function alg_cache(alg::ROS3P, u, rate_prototype, ::Type{uEltypeNoUnits},
316321
end
317322

318323
@cache mutable struct Rosenbrock34Cache{uType, rateType, uNoUnitsType, JType, WType,
319-
TabType, TFType, UFType, F, JCType, GCType, StepLimiter} <:
324+
TabType, TFType, UFType, F, JCType, GCType, StepLimiter, StageLimiter} <:
320325
RosenbrockMutableCache
321326
u::uType
322327
uprev::uType
@@ -343,6 +348,7 @@ end
343348
jac_config::JCType
344349
grad_config::GCType
345350
step_limiter!::StepLimiter
351+
stage_limiter!::StageLimiter
346352
end
347353

348354
function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -382,7 +388,8 @@ function alg_cache(alg::Rodas3, u, rate_prototype, ::Type{uEltypeNoUnits},
382388
Rosenbrock34Cache(u, uprev, du, du1, du2, k1, k2, k3, k4,
383389
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
384390
linsolve_tmp,
385-
linsolve, jac_config, grad_config, alg.step_limiter!)
391+
linsolve, jac_config, grad_config, alg.step_limiter!,
392+
alg.stage_limiter!)
386393
end
387394

388395
struct Rosenbrock34ConstantCache{TF, UF, Tab, JType, WType, F} <:
@@ -460,7 +467,7 @@ struct Rodas3PConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqC
460467
end
461468

462469
@cache mutable struct Rodas23WCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
463-
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter} <:
470+
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
464471
RosenbrockMutableCache
465472
u::uType
466473
uprev::uType
@@ -493,10 +500,11 @@ end
493500
reltol::RTolType
494501
alg::A
495502
step_limiter!::StepLimiter
503+
stage_limiter!::StageLimiter
496504
end
497505

498506
@cache mutable struct Rodas3PCache{uType, rateType, uNoUnitsType, JType, WType, TabType,
499-
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter} <:
507+
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
500508
RosenbrockMutableCache
501509
u::uType
502510
uprev::uType
@@ -529,6 +537,7 @@ end
529537
reltol::RTolType
530538
alg::A
531539
step_limiter!::StepLimiter
540+
stage_limiter!::StageLimiter
532541
end
533542

534543
function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -571,7 +580,8 @@ function alg_cache(alg::Rodas23W, u, rate_prototype, ::Type{uEltypeNoUnits},
571580
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
572581
Rodas23WCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
573582
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
574-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
583+
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
584+
alg.stage_limiter!)
575585
end
576586

577587
TruncatedStacktraces.@truncate_stacktrace Rodas23WCache 1
@@ -615,7 +625,8 @@ function alg_cache(alg::Rodas3P, u, rate_prototype, ::Type{uEltypeNoUnits},
615625
jac_config = build_jac_config(alg, f, uf, du1, uprev, u, tmp, du2)
616626
Rodas3PCache(u, uprev, dense1, dense2, dense3, du, du1, du2, k1, k2, k3, k4, k5,
617627
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
618-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
628+
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
629+
alg.stage_limiter!)
619630
end
620631

621632
TruncatedStacktraces.@truncate_stacktrace Rodas3PCache 1
@@ -663,7 +674,7 @@ struct Rodas4ConstantCache{TF, UF, Tab, JType, WType, F, AD} <: OrdinaryDiffEqCo
663674
end
664675

665676
@cache mutable struct Rodas4Cache{uType, rateType, uNoUnitsType, JType, WType, TabType,
666-
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter} <:
677+
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
667678
RosenbrockMutableCache
668679
u::uType
669680
uprev::uType
@@ -696,6 +707,7 @@ end
696707
reltol::RTolType
697708
alg::A
698709
step_limiter!::StepLimiter
710+
stage_limiter!::StageLimiter
699711
end
700712

701713
function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -739,7 +751,8 @@ function alg_cache(alg::Rodas4, u, rate_prototype, ::Type{uEltypeNoUnits},
739751
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
740752
k5, k6,
741753
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
742-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
754+
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
755+
alg.stage_limiter!)
743756
end
744757

745758
TruncatedStacktraces.@truncate_stacktrace Rodas4Cache 1
@@ -800,7 +813,8 @@ function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
800813
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
801814
k5, k6,
802815
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
803-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
816+
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
817+
alg.stage_limiter!)
804818
end
805819

806820
function alg_cache(alg::Rodas42, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -859,7 +873,8 @@ function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
859873
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
860874
k5, k6,
861875
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
862-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
876+
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
877+
alg.stage_limiter!)
863878
end
864879

865880
function alg_cache(alg::Rodas4P, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -918,7 +933,8 @@ function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
918933
Rodas4Cache(u, uprev, dense1, dense2, du, du1, du2, k1, k2, k3, k4,
919934
k5, k6,
920935
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf, linsolve_tmp,
921-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
936+
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
937+
alg.stage_limiter!)
922938
end
923939

924940
function alg_cache(alg::Rodas4P2, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -951,7 +967,7 @@ end
951967

952968
@cache mutable struct Rosenbrock5Cache{
953969
uType, rateType, uNoUnitsType, JType, WType, TabType,
954-
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter} <:
970+
TFType, UFType, F, JCType, GCType, RTolType, A, StepLimiter, StageLimiter} <:
955971
RosenbrockMutableCache
956972
u::uType
957973
uprev::uType
@@ -987,6 +1003,7 @@ end
9871003
reltol::RTolType
9881004
alg::A
9891005
step_limiter!::StepLimiter
1006+
stage_limiter!::StageLimiter
9901007
end
9911008

9921009
TruncatedStacktraces.@truncate_stacktrace Rosenbrock5Cache 1
@@ -1036,7 +1053,8 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
10361053
k5, k6, k7, k8,
10371054
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
10381055
linsolve_tmp,
1039-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
1056+
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
1057+
alg.stage_limiter!)
10401058
end
10411059

10421060
function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
@@ -1099,7 +1117,8 @@ function alg_cache(
10991117
k5, k6, k7, k8,
11001118
fsalfirst, fsallast, dT, J, W, tmp, atmp, weight, tab, tf, uf,
11011119
linsolve_tmp,
1102-
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!)
1120+
linsolve, jac_config, grad_config, reltol, alg, alg.step_limiter!,
1121+
alg.stage_limiter!)
11031122
end
11041123

11051124
function alg_cache(

0 commit comments

Comments
 (0)