Skip to content

Commit 048ed07

Browse files
Merge #119
119: Decompose run! r=charleskawczynski a=charleskawczynski This PR decomposes `run!` into `check_convergence!`, `jvp!` `get_rtol!`, `solve_krylov!` Next, we should remove the `run!` fallback (i.e., fix #118) Co-authored-by: Charles Kawczynski <[email protected]>
2 parents aa67470 + 13f8d82 commit 048ed07

File tree

5 files changed

+34
-24
lines changed

5 files changed

+34
-24
lines changed

src/solvers/convergence_checker.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ using LinearAlgebra: norm
1212
1313
Checks whether a sequence `val[0], val[1], val[2], ...` has converged to some
1414
limit `L`, given the errors `err[iter] = val[iter] .- L`. This is done by
15-
calling `run!(::ConvergenceChecker, cache, val, err, iter)`, where
15+
calling `check_convergence!(::ConvergenceChecker, cache, val, err, iter)`, where
1616
`val = val[iter]` and `err = err[iter]`. If the value of `L` is not known, `err`
1717
can be an approximation of `err[iter]`. The `cache` for a `ConvergenceChecker`
1818
can be obtained with `allocate_cache(::ConvergenceChecker, val_prototype)`,
@@ -68,7 +68,7 @@ function has_component_converged(alg, cache, val, err, iter)
6868
return all(component_bools)
6969
end
7070

71-
function run!(alg::ConvergenceChecker, cache, val, err, iter)
71+
function check_convergence!(alg::ConvergenceChecker, cache, val, err, iter)
7272
(; norm_condition, component_condition, condition_combiner, norm) = alg
7373
(; norm_cache, component_cache) = cache
7474
if isnothing(norm_condition)

src/solvers/imex_ark.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,13 @@ function step_u!(integrator, cache::IMEXARKCache)
130130
@. residual = temp + dt * a_imp[i, i] * residual - Ui
131131
end
132132
implicit_equation_jacobian! = (jacobian, Ui) -> T_imp!.Wfact(jacobian, Ui, p, dt * a_imp[i, i], t_imp)
133-
run!(newtons_method, newtons_method_cache, U[i], implicit_equation_residual!, implicit_equation_jacobian!)
133+
solve_newton!(
134+
newtons_method,
135+
newtons_method_cache,
136+
U[i],
137+
implicit_equation_residual!,
138+
implicit_equation_jacobian!,
139+
)
134140
end
135141

136142
# We do not need to DSS U[i] again because the implicit solve should

src/solvers/newtons_method.jl

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ struct ForwardDiffStepSize3 <: ForwardDiffStepSize end
125125
Computes the Jacobian-vector product `j(x[n]) * Δx[n]` for a Newton-Krylov
126126
method without directly using the Jacobian `j(x[n])`, and instead only using
127127
`x[n]`, `f(x[n])`, and other function evaluations `f(x′)`. This is done by
128-
calling `run!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f)`. The `jΔx` passed to
128+
calling `jvp!(::JacobianFreeJVP, cache, jΔx, Δx, x, f!, f)`. The `jΔx` passed to
129129
a Jacobian-free JVP is modified in-place. The `cache` can be obtained with
130130
`allocate_cache(::JacobianFreeJVP, x_prototype)`, where `x_prototype` is
131131
`similar` to `x` (and also to `Δx` and `f`).
@@ -146,7 +146,7 @@ end
146146

147147
allocate_cache(::ForwardDiffJVP, x_prototype) = (; x2 = similar(x_prototype), f2 = similar(x_prototype))
148148

149-
function run!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f)
149+
function jvp!(alg::ForwardDiffJVP, cache, jΔx, Δx, x, f!, f)
150150
(; default_step, step_adjustment) = alg
151151
(; x2, f2) = cache
152152
FT = eltype(x)
@@ -160,7 +160,7 @@ end
160160
ForcingTerm
161161
162162
Computes the value of `rtol[n]` for a Newton-Krylov method. This is done by
163-
calling `run!(::ForcingTerm, cache, f, n)`, which returns `rtol[n]`. The `cache`
163+
calling `get_rtol!(::ForcingTerm, cache, f, n)`, which returns `rtol[n]`. The `cache`
164164
can be obtained with `allocate_cache(::ForcingTerm, x_prototype)`, where
165165
`x_prototype` is `similar` to `f`.
166166
@@ -188,7 +188,7 @@ end
188188

189189
allocate_cache(::ConstantForcing, x_prototype) = (;)
190190

191-
function run!(alg::ConstantForcing, cache, f, n)
191+
function get_rtol!(alg::ConstantForcing, cache, f, n)
192192
FT = eltype(f)
193193
return FT(alg.rtol)
194194
end
@@ -230,7 +230,7 @@ function allocate_cache(::EisenstatWalkerForcing, x_prototype)
230230
return (; prev_norm_f = Ref{FT}(), prev_rtol = Ref{FT}())
231231
end
232232

233-
function run!(alg::EisenstatWalkerForcing, cache, f, n)
233+
function get_rtol!(alg::EisenstatWalkerForcing, cache, f, n)
234234
(; initial_rtol, γ, α, min_rtol_threshold, max_rtol) = alg
235235
(; prev_norm_f, prev_rtol) = cache
236236
FT = eltype(f)
@@ -256,7 +256,7 @@ end
256256
257257
Prints information about the Jacobian matrix `j` and the preconditioner `M` (if
258258
it is available) that are passed to a Krylov method. This is done by calling
259-
`run!(::KrylovMethodDebugger, cache, j, M)`. The `cache` can be obtained with
259+
`print_debug!(::KrylovMethodDebugger, cache, j, M)`. The `cache` can be obtained with
260260
`allocate_cache(::KrylovMethodDebugger, x_prototype)`, where `x_prototype` is
261261
`similar` to `x`.
262262
"""
@@ -284,7 +284,9 @@ function allocate_cache(::PrintConditionNumber, x_prototype)
284284
)
285285
end
286286

287-
function run!(::PrintConditionNumber, cache, j, M)
287+
print_debug!(::Nothing, cache, j, M) = nothing
288+
289+
function print_debug!(::PrintConditionNumber, cache, j, M)
288290
(; dense_vector, dense_j, dense_inv_M, dense_inv_M_j) = cache
289291
dense_matrix_from_operator!(dense_j, dense_vector, j)
290292
if M === I
@@ -335,7 +337,7 @@ end
335337
Finds an approximation `Δx[n] ≈ j(x[n]) \\ f(x[n])` for Newton's method such
336338
that `‖f(x[n]) - j(x[n]) * Δx[n]‖ ≤ rtol[n] * ‖f(x[n])‖`, where `rtol[n]` is the
337339
value of the forcing term on iteration `n`. This is done by calling
338-
`run!(::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)`, where `f` is
340+
`solve_krylov!(::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)`, where `f` is
339341
`f(x[n])` and, if it is specified, `j` is either `j(x[n])` or an approximation
340342
of `j(x[n])`. The `Δx` passed to a Krylov method is modified in-place. The
341343
`cache` can be obtained with `allocate_cache(::KrylovMethod, x_prototype)`,
@@ -347,7 +349,7 @@ This is primarily a wrapper for a `Krylov.KrylovSolver` from `Krylov.jl`. In
347349
`l = length(x_prototype)` and `Krylov.ktypeof(x_prototype)` is a subtype of
348350
`DenseVector` that can be used to store `x_prototype`. By default, the solver
349351
is a `Krylov.GmresSolver` with a Krylov subspace size of 20 (the default Krylov
350-
subspace size for this solver in `Krylov.jl`). In `run!`, the solver is run with
352+
subspace size for this solver in `Krylov.jl`). In `solve_krylov!`, the solver is run with
351353
`Krylov.solve!(solver, opj, f; M, ldiv, atol, rtol, verbose, solve_kwargs...)`.
352354
The solver's type can be changed by specifying a different value for `type`,
353355
though this value has to be wrapped in a `Val` to avoid runtime compilation.
@@ -419,20 +421,20 @@ function allocate_cache(alg::KrylovMethod, x_prototype)
419421
)
420422
end
421423

422-
function run!(alg::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)
424+
function solve_krylov!(alg::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)
423425
(; jacobian_free_jvp, forcing_term, solve_kwargs) = alg
424426
(; disable_preconditioner, verbose, debugger) = alg
425427
type = solver_type(alg)
426428
(; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) = cache
427429
jΔx!(jΔx, Δx) =
428430
isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) :
429-
run!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f)
431+
jvp!(jacobian_free_jvp, jacobian_free_jvp_cache, jΔx, Δx, x, f!, f)
430432
opj = LinearOperator(eltype(x), length(x), length(x), false, false, jΔx!)
431433
M = disable_preconditioner || isnothing(j) || isnothing(jacobian_free_jvp) ? I : j
432-
run!(debugger, debugger_cache, opj, M)
434+
print_debug!(debugger, debugger_cache, opj, M)
433435
ldiv = true
434436
atol = zero(eltype(Δx))
435-
rtol = run!(forcing_term, forcing_term_cache, f, n)
437+
rtol = get_rtol!(forcing_term, forcing_term_cache, f, n)
436438
verbose = Int(verbose)
437439
Krylov.solve!(solver, opj, f; M, ldiv, atol, rtol, verbose, solve_kwargs...)
438440
iter = solver.stats.niter
@@ -466,7 +468,7 @@ end
466468
467469
Solves the equation `f(x) = 0`, using the Jacobian (or an approximation of the
468470
Jacobian) `j(x) = f'(x)` if it is available. This is done by calling
469-
`run!(::NewtonsMethod, cache, x, f!, j! = nothing)`, where `f!(f, x)` is a
471+
`solve_newton!(::NewtonsMethod, cache, x, f!, j! = nothing)`, where `f!(f, x)` is a
470472
function that sets `f(x)` in-place and, if it is specified, `j!(j, x)` is a
471473
function that sets `j(x)` in-place. The `x` passed to Newton's method is
472474
modified in-place, and its initial value is used as a starting guess for the
@@ -508,7 +510,7 @@ If `j(x)` changes sufficiently slowly, `update_j` may be changed from
508510
`UpdateEvery(NewNewtonIteration)` to some other `UpdateSignalHandler` that
509511
gets triggered less frequently, such as `UpdateEvery(NewNewtonSolve)`. This
510512
can be used to make the approximation `j(x[n]) ≈ j(x₀)`, where `x₀` is a
511-
previous value of `x[n]` (possibly even a value from a previous `run!` of
513+
previous value of `x[n]` (possibly even a value from a previous `solve_newton!` of
512514
Newton's method). When Newton's method uses such an approximation, it is called
513515
the "chord method".
514516
@@ -558,7 +560,7 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing)
558560
)
559561
end
560562

561-
function run!(alg::NewtonsMethod, cache, x, f!, j! = nothing)
563+
function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing)
562564
(; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg
563565
(; update_j_cache, krylov_method_cache, convergence_checker_cache) = cache
564566
(; Δx, f, j) = cache
@@ -581,13 +583,13 @@ function run!(alg::NewtonsMethod, cache, x, f!, j! = nothing)
581583
ldiv!(Δx, j, f)
582584
end
583585
else
584-
run!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, j)
586+
solve_krylov!(krylov_method, krylov_method_cache, Δx, x, f!, f, n, j)
585587
end
586588
verbose && @info "Newton iteration $n: ‖x‖ = $(norm(x)), ‖Δx‖ = $(norm(Δx))"
587589

588590
# Check for convergence if necessary.
589591
if !isnothing(convergence_checker)
590-
run!(convergence_checker, convergence_checker_cache, x, Δx, n) && break
592+
check_convergence!(convergence_checker, convergence_checker_cache, x, Δx, n) && break
591593
n == max_iters && @warn "Newton's method did not converge within $n iterations"
592594
end
593595
end

test/test_convergence_checker.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ClimaTimeSteppers, Test
2+
import ClimaTimeSteppers as CTS
23

34
@testset "ConvergenceChecker" begin
45
val_func(iter) = [60.0, -80.0]
@@ -16,9 +17,9 @@ using ClimaTimeSteppers, Test
1617
cache = allocate_cache(checker, val_func(0))
1718
for (err_func, last_iter) in ((err_func1, last_iter1), (err_func2, last_iter2))
1819
for iter in 0:(last_iter - 1)
19-
run!(checker, cache, val_func(iter), err_func(iter), iter) && return false
20+
CTS.check_convergence!(checker, cache, val_func(iter), err_func(iter), iter) && return false
2021
end
21-
run!(checker, cache, val_func(last_iter), err_func(last_iter), last_iter) || return false
22+
CTS.check_convergence!(checker, cache, val_func(last_iter), err_func(last_iter), last_iter) || return false
2223
end
2324
return true
2425
end

test/test_newtons_method.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using ClimaTimeSteppers, LinearAlgebra, Random, Test
2+
import ClimaTimeSteppers as CTS
23

34
function linear_equation(FT, n)
45
rng = MersenneTwister(1)
@@ -60,7 +61,7 @@ end
6061
x = copy(x_init)
6162
j_prototype = similar(x, length(x), length(x))
6263
cache = allocate_cache(alg, x, use_j ? j_prototype : nothing)
63-
run!(alg, cache, x, f!, use_j ? j! : nothing)
64+
CTS.solve_newton!(alg, cache, x, f!, use_j ? j! : nothing)
6465
@test norm(x .- x_exact) / norm(x_exact) < rtol
6566
end
6667
end

0 commit comments

Comments
 (0)