Skip to content

Commit efec3c3

Browse files
Merge pull request #222 from CliMA/ck/fix
Callback fixes
2 parents f0c4ed7 + f497dde commit efec3c3

File tree

4 files changed

+87
-43
lines changed

4 files changed

+87
-43
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaTimeSteppers"
22
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
33
authors = ["Climate Modeling Alliance"]
4-
version = "0.7.10"
4+
version = "0.7.11"
55

66
[deps]
77
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

src/nl_solvers/newtons_method.jl

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end
130130
Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov
131131
method without directly using the Jacobian `j(x[n])`, and instead only using
132132
`x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by
133-
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f)`. The `jΔx` passed to
134-
a Jacobian-free JVP is modified in-place. The `cache` can be obtained with
135-
`allocate_cache(::JacobianFreeJVP, x_prototype)`, where `x_prototype` is
136-
`similar` to `x` (and also to `Δx` and `f`).
133+
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)`.
134+
The `jΔx` passed to a Jacobian-free JVP is modified in-place. The `cache` can
135+
be obtained with `allocate_cache(::JacobianFreeJVP, x_prototype)`, where
136+
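`x_prototype` is `similar` to `x` (and also to `Δx` and `f`). In `ForwardDiffJVP`, a `post_implicit!` that is not `nothing` is applied to the perturbed state before `f!` is evaluated on it.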
137137
"""
138138
abstract type JacobianFreeJVP end
139139

@@ -151,12 +151,13 @@ end
151151

152152
allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = similar(x_prototype), f2 = similar(x_prototype))
153153

154-
function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f)
154+
function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f, post_implicit!)
155155
(; default_step, step_adjustment) = alg
156156
(; x2, f2) = cache
157157
FT = eltype(x)
158158
ε = FT(step_adjustment) * default_step(Δx, x)
159159
@. x2 = x + ε * Δx
160+
isnothing(post_implicit!) || post_implicit!(x2)
160161
f!(f2, x2)
161162
@. jΔx = (f2 - f) / ε
162163
end
@@ -342,10 +343,10 @@ end
342343
Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such
343344
that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the
344345
value of the forcing term on iteration `n`. This is done by calling
345-
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)`, where `f` is
346-
`f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an approximation
347-
of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place. The
348-
`cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`,
346+
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)`,
347+
where `f` is `f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an
348+
approximation of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place.
349+
The `cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`,
349350
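where `x_prototype` is `similar` to `x` (and also to `Δx` and `f`). The `post_implicit!` argument is only used when `jacobian_free_jvp` is set, in which case it is forwarded to `jvp!`.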
350351
351352
This is primarily a wrapper for a `Krylov.KrylovSolver` from `Krylov.jl`. In
@@ -427,14 +428,14 @@ function allocate_cache(alg::KrylovMethod, x_prototype)
427428
)
428429
end
429430

430-
function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)
431+
function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, post_implicit!, j = nothing)
431432
(; jacobian_free_jvp, forcing_term, solve_kwargs) = alg
432433
(; disable_preconditioner, debugger) = alg
433434
type = solver_type(alg)
434435
(; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache
435436
jΔx!(jΔx, Δx) =
436437
isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) :
437-
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f)
438+
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f, post_implicit!)
438439
opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!)
439440
M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j
440441
print_debug!(debugger, debugger_cache, opj, M)
@@ -566,9 +567,25 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing)
566567
)
567568
end
568569

569-
solve_newton!(alg::NewtonsMethod, cache::Nothing, x, f!, j! = nothing, post_implicit! = nothing) = nothing
570-
571-
function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, post_implicit! = nothing)
570+
solve_newton!(
571+
alg::NewtonsMethod,
572+
cache::Nothing,
573+
x,
574+
f!,
575+
j! = nothing,
576+
post_implicit! = nothing,
577+
post_implicit_last! = nothing,
578+
) = nothing
579+
580+
function solve_newton!(
581+
alg::NewtonsMethod,
582+
cache,
583+
x,
584+
f!,
585+
j! = nothing,
586+
post_implicit! = nothing,
587+
post_implicit_last! = nothing,
588+
)
572589
(; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg
573590
(; krylov_method_cache, convergence_checker_cache) = cache
574591
(; Δx, f, j) = cache
@@ -588,16 +605,20 @@ function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing, post_impl
588605
ldiv!(Δx, j, f)
589606
end
590607
else
591-
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, j)
608+
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, post_implicit!, j)
592609
end
593610
is_verbose(verbose) && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"
594611

595612
x .-= Δx
596-
isnothing(post_implicit!) || post_implicit!(x)
597613
# Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed.
598614
# Check for convergence if necessary.
599615
if is_converged!(convergence_checker, convergence_checker_cache, x, Δx, n)
616+
isnothing(post_implicit_last!) || post_implicit_last!(x)
600617
break
618+
elseif n == max_iters
619+
isnothing(post_implicit_last!) || post_implicit_last!(x)
620+
else
621+
isnothing(post_implicit!) || post_implicit!(x)
601622
end
602623
if is_verbose(verbose) && n == max_iters
603624
@warn "Newton's method did not converge within $n iterations: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"

src/solvers/imex_ark.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ end
4747

4848
step_u!(integrator, cache::IMEXARKCache) = step_u!(integrator, cache, integrator.sol.prob.f, integrator.alg.name)
4949

50-
include("hard_coded_ars343.jl")
50+
# include("hard_coded_ars343.jl")
5151
# generic fallback
5252
function step_u!(integrator, cache::IMEXARKCache, f, name)
5353
(; u, p, t, dt, alg) = integrator
54+
(; post_explicit!, post_implicit!) = f
5455
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
5556
(; tableau, newtons_method) = alg
5657
(; a_exp, b_exp, a_imp, b_imp, c_exp, c_imp) = tableau
@@ -114,11 +115,14 @@ function step_u!(integrator, cache::IMEXARKCache, f, name)
114115
dss!(U, p, t_exp)
115116
end
116117

117-
if !isnothing(T_imp!) && !iszero(a_imp[i, i]) # Implicit solve
118+
if !(!isnothing(T_imp!) && !iszero(a_imp[i, i])) # No implicit solve for this stage (the `else` branch performs it)
119+
post_explicit!(U, p, t_imp)
120+
else
118121
@assert !isnothing(newtons_method)
119122
NVTX.@range "temp = U" color = colorant"yellow" begin
120123
@. temp = U
121124
end
125+
post_explicit!(U, p, t_imp)
122126
# TODO: can/should we remove these closures?
123127
implicit_equation_residual! =
124128
(residual, Ui) -> begin
@@ -130,6 +134,18 @@ function step_u!(integrator, cache::IMEXARKCache, f, name)
130134
end
131135
end
132136
implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
137+
call_post_implicit! = Ui -> begin
138+
post_implicit!(Ui, p, t_imp)
139+
end
140+
call_post_implicit_last! =
141+
Ui -> begin
142+
if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !iszero(a_imp[i, i])
143+
# If T_imp[i] is being treated implicitly, ensure that it
144+
# exactly satisfies the implicit equation.
145+
@. T_imp[i] = (Ui - temp) / (dt * a_imp[i, i])
146+
end
147+
post_implicit!(Ui, p, t_imp)
148+
end
133149

134150
NVTX.@range "solve_newton!" color = colorant"yellow" begin
135151
solve_newton!(
@@ -138,6 +154,8 @@ function step_u!(integrator, cache::IMEXARKCache, f, name)
138154
U,
139155
implicit_equation_residual!,
140156
implicit_equation_jacobian!,
157+
call_post_implicit!,
158+
call_post_implicit_last!,
141159
)
142160
end
143161
end
@@ -147,19 +165,11 @@ function step_u!(integrator, cache::IMEXARKCache, f, name)
147165
# tendency only acts in the vertical direction).
148166

149167
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
150-
if !isnothing(T_imp!)
151-
if iszero(a_imp[i, i])
152-
# If its coefficient is 0, T_imp[i] is effectively being
153-
# treated explicitly.
154-
NVTX.@range "T_imp!" color = colorant"yellow" begin
155-
T_imp!(T_imp[i], U, p, t_imp)
156-
end
157-
else
158-
# If T_imp[i] is being treated implicitly, ensure that it
159-
# exactly satisfies the implicit equation.
160-
NVTX.@range "T_imp=(U-temp)/(dt*a_imp)" color = colorant"yellow" begin
161-
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
162-
end
168+
if iszero(a_imp[i, i]) && !isnothing(T_imp!)
169+
# If its coefficient is 0, T_imp[i] is effectively being
170+
# treated explicitly.
171+
NVTX.@range "T_imp!" color = colorant"yellow" begin
172+
T_imp!(T_imp[i], U, p, t_imp)
163173
end
164174
end
165175
end

src/solvers/imex_ssprk.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ step_u!(integrator, cache::IMEXSSPRKCache) = step_u!(integrator, cache, integrat
5656

5757
function step_u!(integrator, cache::IMEXSSPRKCache, f, name)
5858
(; u, p, t, dt, alg) = integrator
59+
(; post_explicit!, post_implicit!) = f
5960
(; T_lim!, T_exp!, T_imp!, lim!, dss!) = f
6061
(; tableau, newtons_method) = alg
6162
(; a_imp, b_imp, c_exp, c_imp) = tableau
@@ -104,21 +105,39 @@ function step_u!(integrator, cache::IMEXSSPRKCache, f, name)
104105
end
105106
end
106107

107-
if !isnothing(T_imp!) && !iszero(a_imp[i, i]) # Implicit solve
108+
if !(!isnothing(T_imp!) && !iszero(a_imp[i, i])) # No implicit solve for this stage (the `else` branch performs it)
109+
post_explicit!(U, p, t_imp)
110+
else
108111
@assert !isnothing(newtons_method)
109112
@. temp = U
113+
post_explicit!(U, p, t_imp)
110114
# TODO: can/should we remove these closures?
111115
implicit_equation_residual! = (residual, Ui) -> begin
112116
T_imp!(residual, Ui, p, t_imp)
113117
@. residual = temp + dt * a_imp[i, i] * residual - Ui
114118
end
115119
implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
120+
call_post_implicit! = Ui -> begin
121+
post_implicit!(Ui, p, t_imp)
122+
end
123+
call_post_implicit_last! =
124+
Ui -> begin
125+
if (!all(iszero, a_imp[:, i]) || !iszero(b_imp[i])) && !iszero(a_imp[i, i])
126+
# If T_imp[i] is being treated implicitly, ensure that it
127+
# exactly satisfies the implicit equation.
128+
@. T_imp[i] = (Ui - temp) / (dt * a_imp[i, i])
129+
end
130+
post_implicit!(Ui, p, t_imp)
131+
end
132+
116133
solve_newton!(
117134
newtons_method,
118135
newtons_method_cache,
119136
U,
120137
implicit_equation_residual!,
121138
implicit_equation_jacobian!,
139+
call_post_implicit!,
140+
call_post_implicit_last!,
122141
)
123142
end
124143

@@ -127,16 +146,10 @@ function step_u!(integrator, cache::IMEXSSPRKCache, f, name)
127146
# tendency only acts in the vertical direction).
128147

129148
if !all(iszero, a_imp[:, i]) || !iszero(b_imp[i])
130-
if !isnothing(T_imp!)
131-
if iszero(a_imp[i, i])
132-
# If its coefficient is 0, T_imp[i] is effectively being
133-
# treated explicitly.
134-
T_imp!(T_imp[i], U, p, t_imp)
135-
else
136-
# If T_imp[i] is being treated implicitly, ensure that it
137-
# exactly satisfies the implicit equation.
138-
@. T_imp[i] = (U - temp) / (dt * a_imp[i, i])
139-
end
149+
if iszero(a_imp[i, i]) && !isnothing(T_imp!)
150+
# If its coefficient is 0, T_imp[i] is effectively being
151+
# treated explicitly.
152+
T_imp!(T_imp[i], U, p, t_imp)
140153
end
141154
end
142155

0 commit comments

Comments
 (0)