Skip to content

Commit 9048ca6

Browse files
committed
Rename post_explicit! to cache! and post_implicit! to cache_imp!
1 parent 96b7067 commit 9048ca6

File tree

11 files changed

+102
-104
lines changed

11 files changed

+102
-104
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaTimeSteppers"
22
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
33
authors = ["Climate Modeling Alliance"]
4-
version = "0.7.40"
4+
version = "0.8.0"
55

66
[deps]
77
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"

docs/src/api/ode_solvers.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ CurrentModule = ClimaTimeSteppers
77
## Interface
88

99
```@docs
10+
ClimaODEFunction
1011
AbstractAlgorithmConstraint
1112
Unconstrained
1213
SSP

ext/ClimaTimeSteppersBenchmarkToolsExt.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ n_calls_per_step(::CTS.ARS343, max_newton_iters) = Dict(
3535
"T_exp_T_lim!" => 4,
3636
"lim!" => 4,
3737
"dss!" => 4,
38-
"post_explicit!" => 3,
39-
"post_implicit!" => 4,
38+
"cache!" => 3,
39+
"cache_imp!" => 4,
4040
"step!" => 1,
4141
)
4242
function n_calls_per_step(alg::CTS.RosenbrockAlgorithm)
@@ -47,8 +47,8 @@ function n_calls_per_step(alg::CTS.RosenbrockAlgorithm)
4747
"T_exp_T_lim!" => CTS.n_stages(alg.tableau),
4848
"lim!" => 0,
4949
"dss!" => CTS.n_stages(alg.tableau),
50-
"post_explicit!" => 0,
51-
"post_implicit!" => CTS.n_stages(alg.tableau),
50+
"cache!" => 0,
51+
"cache_imp!" => CTS.n_stages(alg.tableau),
5252
"step!" => 1,
5353
)
5454
end
@@ -59,8 +59,7 @@ function maybe_push!(trials₀, name, f!, args, kwargs, only)
5959
end
6060
end
6161

62-
const allowed_names =
63-
["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "post_explicit!", "post_implicit!", "step!"]
62+
const allowed_names = ["Wfact", "ldiv!", "T_imp!", "T_exp_T_lim!", "lim!", "dss!", "cache!", "cache_imp!", "step!"]
6463

6564
"""
6665
benchmark_step(
@@ -89,8 +88,8 @@ Benchmark a DistributedODEIntegrator given:
8988
- "T_exp_T_lim!"
9089
- "lim!"
9190
- "dss!"
92-
- "post_explicit!"
93-
- "post_implicit!"
91+
- "cache!"
92+
- "cache_imp!"
9493
- "step!"
9594
"""
9695
function CTS.benchmark_step(
@@ -123,8 +122,8 @@ function CTS.benchmark_step(
123122
maybe_push!(trials₀, "T_exp_T_lim!", remaining_fun(integrator), remaining_args(integrator), kwargs, only)
124123
maybe_push!(trials₀, "lim!", f.lim!, (Xlim, p, t, u), kwargs, only)
125124
maybe_push!(trials₀, "dss!", f.dss!, (u, p, t), kwargs, only)
126-
maybe_push!(trials₀, "post_explicit!", f.post_explicit!, (u, p, t), kwargs, only)
127-
maybe_push!(trials₀, "post_implicit!", f.post_implicit!, (u, p, t), kwargs, only)
125+
maybe_push!(trials₀, "cache!", f.cache!, (u, p, t), kwargs, only)
126+
maybe_push!(trials₀, "cache_imp!", f.cache_imp!, (u, p, t), kwargs, only)
128127
maybe_push!(trials₀, "step!", SciMLBase.step!, (integrator, ), kwargs, only)
129128
#! format: on
130129

src/functions.jl

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,26 +4,51 @@ export ClimaODEFunction, ForwardEulerODEFunction
44

55
abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end
66

7-
struct ClimaODEFunction{TEL, TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
7+
"""
8+
ClimaODEFunction(; T_imp!, [dss!], [cache!], [cache_imp!])
9+
ClimaODEFunction(; T_exp!, T_lim!, [T_imp!], [lim!], [dss!], [cache!], [cache_imp!])
10+
ClimaODEFunction(; T_exp_lim!, [T_imp!], [lim!], [dss!], [cache!], [cache_imp!])
11+
12+
Container for all functions used to advance through a timestep:
13+
- `T_imp!(T_imp, u, p, t)`: sets the implicit tendency
14+
- `T_exp!(T_exp, u, p, t)`: sets the component of the explicit tendency that
15+
is not passed through the limiter
16+
- `T_lim!(T_lim, u, p, t)`: sets the component of the explicit tendency that
17+
is passed through the limiter
18+
- `T_exp_lim!(T_exp, T_lim, u, p, t)`: fused alternative to the separate
19+
functions `T_exp!` and `T_lim!`
20+
- `lim!(u, p, t, u_ref)`: applies the limiter to every state `u` that has
21+
been incremented from `u_ref` by the explicit tendency component `T_lim!`
22+
- `dss!(u, p, t)`: applies direct stiffness summation to every state `u`,
23+
except for intermediate states generated within the implicit solver
24+
- `cache!(u, p, t)`: updates the cache `p` to reflect the state `u` before
25+
the first timestep and on every subsequent timestepping stage
26+
- `cache_imp!(u, p, t)`: updates the components of the cache `p` that are
27+
required to evaluate `T_imp!` and its Jacobian within the implicit solver
28+
By default, `lim!`, `dss!`, and `cache!` all do nothing, and `cache_imp!` is
29+
identical to `cache!`. Any of the tendency functions can be set to `nothing` in
30+
order to avoid corresponding allocations in the integrator.
31+
"""
32+
struct ClimaODEFunction{TEL, TL, TE, TI, L, D, C, CI} <: AbstractClimaODEFunction
833
T_exp_T_lim!::TEL
934
T_lim!::TL
1035
T_exp!::TE
1136
T_imp!::TI
1237
lim!::L
1338
dss!::D
14-
post_explicit!::PE
15-
post_implicit!::PI
39+
cache!::C
40+
cache_imp!::CI
1641
function ClimaODEFunction(;
17-
T_exp_T_lim! = nothing, # nothing or (uₜ_exp, uₜ_lim, u, p, t) -> ...
18-
T_lim! = nothing, # nothing or (uₜ, u, p, t) -> ...
19-
T_exp! = nothing, # nothing or (uₜ, u, p, t) -> ...
20-
T_imp! = nothing, # nothing or (uₜ, u, p, t) -> ...
42+
T_exp_T_lim! = nothing,
43+
T_lim! = nothing,
44+
T_exp! = nothing,
45+
T_imp! = nothing,
2146
lim! = (u, p, t, u_ref) -> nothing,
2247
dss! = (u, p, t) -> nothing,
23-
post_explicit! = (u, p, t) -> nothing,
24-
post_implicit! = (u, p, t) -> nothing,
48+
cache! = (u, p, t) -> nothing,
49+
cache_imp! = cache!,
2550
)
26-
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!)
51+
args = (T_exp_T_lim!, T_lim!, T_exp!, T_imp!, lim!, dss!, cache!, cache_imp!)
2752

2853
if !isnothing(T_exp_T_lim!)
2954
@assert isnothing(T_exp!) "`T_exp_T_lim!` was passed, `T_exp!` must be `nothing`"

src/integrators.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ function DiffEqBase.__init(
147147
tdir,
148148
)
149149
if prob.f isa ClimaODEFunction
150-
(; post_explicit!) = prob.f
151-
isnothing(post_explicit!) || post_explicit!(u0, p, t0)
150+
(; cache!) = prob.f
151+
isnothing(cache!) || cache!(u0, p, t0)
152152
end
153153
DiffEqBase.initialize!(callback, u0, t0, integrator)
154154
return integrator

src/nl_solvers/newtons_method.jl

Lines changed: 13 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end
130130
Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov
131131
method without directly using the Jacobian `j(x[n])`, and instead only using
132132
`x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by
133-
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)`.
133+
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, prepare_for_f!)`.
134134
The `jΔx` passed to a Jacobian-free JVP is modified in-place. The `cache` can
135135
be obtained with `allocate_cache(::JacobianFreeJVP, x_prototype)`, where
136136
`x_prototype` is `similar` to `x` (and also to `Δx` and `f`).
@@ -151,13 +151,13 @@ end
151151

152152
allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = zero(x_prototype), f2 = zero(x_prototype))
153153

154-
function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)
154+
function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, prepare_for_f!)
155155
(; default_step, step_adjustment) = alg
156156
(; x2, f2) = cache
157157
FT = eltype(x)
158158
ε = FT(step_adjustment) * default_step(Δx, x)
159159
@. x2 = x + ε * Δx
160-
isnothing(post_implicit!) || post_implicit!(x2)
160+
isnothing(prepare_for_f!) || prepare_for_f!(x2)
161161
f!(f2, x2)
162162
@. jΔx = (f2 - f) / ε
163163
end
@@ -343,7 +343,7 @@ end
343343
Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such
344344
that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the
345345
value of the forcing term on iteration `n`. This is done by calling
346-
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)`,
346+
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, prepare_for_f!, j = nothing)`,
347347
where `f` is `f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an
348348
approximation of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place.
349349
The `cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`,
@@ -428,14 +428,14 @@ function allocate_cache(alg::KrylovMethod, x_prototype)
428428
)
429429
end
430430

431-
NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)
431+
NVTX.@annotate function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, prepare_for_f!, j = nothing)
432432
(; jacobian_free_jvp, forcing_term, solve_kwargs) = alg
433433
(; disable_preconditioner, debugger) = alg
434434
type = solver_type(alg)
435435
(; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache
436436
jΔx!(jΔx, Δx) =
437437
isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) :
438-
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, post_implicit!)
438+
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, prepare_for_f!)
439439
opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!)
440440
M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j
441441
print_debug!(debugger, debugger_cache, opj, M)
@@ -567,25 +567,9 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing)
567567
)
568568
end
569569

570-
solve_newton!(
571-
alg::NewtonsMethod,
572-
cache::Nothing,
573-
x,
574-
f!,
575-
j! = nothing,
576-
post_implicit! = nothing,
577-
post_implicit_last! = nothing,
578-
) = nothing
579-
580-
NVTX.@annotate function solve_newton!(
581-
alg::NewtonsMethod,
582-
cache,
583-
x,
584-
f!,
585-
j! = nothing,
586-
post_implicit! = nothing,
587-
post_implicit_last! = nothing,
588-
)
570+
solve_newton!(alg::NewtonsMethod, cache::Nothing, x, f!, j! = nothing, prepare_for_f! = nothing) = nothing
571+
572+
NVTX.@annotate function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, prepare_for_f! = nothing)
589573
(; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg
590574
(; krylov_method_cache, convergence_checker_cache) = cache
591575
(; Δx, f, j) = cache
@@ -605,22 +589,18 @@ NVTX.@annotate function solve_newton!(
605589
ldiv!(Δx, j, f)
606590
end
607591
else
608-
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, post_implicit!, j)
592+
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, prepare_for_f!, j)
609593
end
610594
is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"
611595

612596
x .-= Δx
613597
# Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed.
614598
# Check for convergence if necessary.
615599
if is_converged!(convergence_checker, convergence_checker_cache, x, Δx, n)
616-
isnothing(post_implicit_last!) || post_implicit_last!(x)
617600
break
618-
elseif n == max_iters
619-
isnothing(post_implicit_last!) || post_implicit_last!(x)
620-
else
621-
isnothing(post_implicit!) || post_implicit!(x)
622-
end
623-
if is_verbose(verbose) && n == max_iters
601+
elseif n < max_iters
602+
isnothing(prepare_for_f!) || prepare_for_f!(x)
603+
elseif is_verbose(verbose)
624604
@warn "Newton's method did not converge within $n iterations: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"
625605
end
626606
end

src/solvers/hard_coded_ars343.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
44
(; u, p, t, dt, sol, alg) = integrator
55
(; f) = sol.prob
66
(; T_imp!, lim!, dss!) = f
7-
(; post_explicit!, post_implicit!) = f
7+
(; cache!, cache_imp!) = f
88
(; tableau, newtons_method) = alg
99
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
1010
(; U, T_lim, T_exp, T_imp, temp, γ, newtons_method_cache) = cache
@@ -35,27 +35,27 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
3535
@. temp = U # used in closures
3636
let i = i
3737
t_imp = t + dt * c_imp[i]
38-
post_implicit!(U, p, t_imp)
38+
cache_imp!(U, p, t_imp)
3939
implicit_equation_residual! = (residual, Ui) -> begin
4040
T_imp!(residual, Ui, p, t_imp)
4141
@. residual = temp + dt * a_imp[i, i] * residual - Ui
4242
end
4343
implicit_equation_jacobian! = (jacobian, Ui) -> begin
4444
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
4545
end
46-
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
46+
call_cache_imp! = Ui -> cache_imp!(Ui, p, t_imp)
4747
solve_newton!(
4848
newtons_method,
4949
newtons_method_cache,
5050
U,
5151
implicit_equation_residual!,
5252
implicit_equation_jacobian!,
53-
call_post_implicit!,
53+
call_cache_imp!,
5454
nothing,
5555
)
5656
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
5757
dss!(U, p, t_imp)
58-
post_explicit!(U, p, t_imp)
58+
cache!(U, p, t_imp)
5959
end
6060
T_lim!(T_lim[i], U, p, t_exp)
6161
T_exp!(T_exp[i], U, p, t_exp)
@@ -69,27 +69,27 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
6969
@. temp = U # used in closures
7070
let i = i
7171
t_imp = t + dt * c_imp[i]
72-
post_implicit!(U, p, t_imp)
72+
cache_imp!(U, p, t_imp)
7373
implicit_equation_residual! = (residual, Ui) -> begin
7474
T_imp!(residual, Ui, p, t_imp)
7575
@. residual = temp + dt * a_imp[i, i] * residual - Ui
7676
end
7777
implicit_equation_jacobian! = (jacobian, Ui) -> begin
7878
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
7979
end
80-
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
80+
call_cache_imp! = Ui -> cache_imp!(Ui, p, t_imp)
8181
solve_newton!(
8282
newtons_method,
8383
newtons_method_cache,
8484
U,
8585
implicit_equation_residual!,
8686
implicit_equation_jacobian!,
87-
call_post_implicit!,
87+
call_cache_imp!,
8888
nothing,
8989
)
9090
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
9191
dss!(U, p, t_imp)
92-
post_explicit!(U, p, t_imp)
92+
cache!(U, p, t_imp)
9393
end
9494
T_lim!(T_lim[i], U, p, t_exp)
9595
T_exp!(T_exp[i], U, p, t_exp)
@@ -108,27 +108,27 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
108108
@. temp = U # used in closures
109109
let i = i
110110
t_imp = t + dt * c_imp[i]
111-
post_implicit!(U, p, t_imp)
111+
cache_imp!(U, p, t_imp)
112112
implicit_equation_residual! = (residual, Ui) -> begin
113113
T_imp!(residual, Ui, p, t_imp)
114114
@. residual = temp + dt * a_imp[i, i] * residual - Ui
115115
end
116116
implicit_equation_jacobian! = (jacobian, Ui) -> begin
117117
T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
118118
end
119-
call_post_implicit! = Ui -> post_implicit!(Ui, p, t_imp)
119+
call_cache_imp! = Ui -> cache_imp!(Ui, p, t_imp)
120120
solve_newton!(
121121
newtons_method,
122122
newtons_method_cache,
123123
U,
124124
implicit_equation_residual!,
125125
implicit_equation_jacobian!,
126-
call_post_implicit!,
126+
call_cache_imp!,
127127
nothing,
128128
)
129129
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
130130
dss!(U, p, t_imp)
131-
post_explicit!(U, p, t_imp)
131+
cache!(U, p, t_imp)
132132
end
133133
T_lim!(T_lim[i], U, p, t_exp)
134134
T_exp!(T_exp[i], U, p, t_exp)
@@ -145,6 +145,6 @@ function step_u!(integrator, cache::IMEXARKCache, ::ARS343)
145145
dt * b_imp[3] * T_imp[3] +
146146
dt * b_imp[4] * T_imp[4]
147147
dss!(u, p, t_final)
148-
post_explicit!(u, p, t_final)
148+
cache!(u, p, t_final)
149149
return u
150150
end

0 commit comments

Comments
 (0)