Skip to content

Commit 506e01c

Browse files
Merge #101
101: Clean up update_j and KrylovMethod r=charleskawczynski a=dennisYatunin Co-authored-by: Dennis Yatunin <[email protected]>
2 parents d03c7c0 + 09eaad6 commit 506e01c

File tree

4 files changed

+126
-71
lines changed

4 files changed

+126
-71
lines changed

docs/src/newtons_method.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@ MultipleConditions
4444
UpdateSignalHandler
4545
UpdateEvery
4646
UpdateEveryN
47+
UpdateEveryDt
4748
UpdateSignal
48-
NewStep
49+
NewTimeStep
4950
NewNewtonSolve
5051
NewNewtonIteration
5152
```

src/solvers/imex_ark.jl

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,15 @@ function cache(
220220
i -> Symbol(:f, χ, :_, i) => similar(u),
221221
filter(i -> save_tendency(i, a), i_range(a)),
222222
)
223+
γs = unique(filter(!iszero, diag(as[2])))
224+
γ = length(γs) == 1 ? γs[1] : nothing
223225
u = prob.u0
224226
Uis = map(
225227
i -> Symbol(:U, i) => similar(u),
226228
filter(i -> !(i in u_alias_is(as[1], as[2])), i_range(as[1])[1:end - 1])
227229
)
228230
_cache = NamedTuple((
231+
:γ => γ,
229232
:U_temp => similar(u),
230233
Uis...,
231234
f_cache(:exp, as[1], typeof(prob.f.f2))...,
@@ -253,6 +256,14 @@ struct ImplicitErrorJacobian{W, P, T}
253256
t::T
254257
Δt::T
255258
end
259+
struct FirstImplicitErrorJacobian{W, U, P, T, Γ}
260+
Wfact!::W
261+
u::U
262+
p::P
263+
t::T
264+
Δt::T
265+
γ::Γ
266+
end
256267

257268
(implicit_error::ImplicitError)(f, u) =
258269
implicit_error(f, u, implicit_error.ode_f!)
@@ -266,6 +277,14 @@ function ((; û, p, t, Δt)::ImplicitError)(f, u, ode_f!)
266277
f .= û .+ Δt .* f .- u
267278
end
268279
((; Wfact!, p, t, Δt)::ImplicitErrorJacobian)(j, u) = Wfact!(j, u, p, Δt, t)
280+
function ((; Wfact!, u, p, t, Δt, γ)::FirstImplicitErrorJacobian)(j)
281+
isnothing(γ) &&
282+
error(
283+
"Cannot compute implicit error Jacobian for timestep becasue a_imp \
284+
does not have a unique value of γ. Try using a different tableau."
285+
)
286+
Wfact!(j, u, p, Δt * typeof(Δt)(γ), t)
287+
end
269288

270289
function step_u_expr(
271290
::Type{<:IMEXARKCache{as, cs}},
@@ -296,13 +315,11 @@ function step_u_expr(
296315
(; f1, f2) = f;
297316
(; newtons_method) = alg;
298317
(; _cache, newtons_method_cache) = cache;
299-
isnothing(f1.Wfact) || run!(
300-
newtons_method.update_j,
301-
newtons_method_cache.update_j_cache,
302-
NewStep(),
303-
ImplicitErrorJacobian(f1.Wfact, p, t, dt * $(FT(as[1][end, end]))),
304-
newtons_method_cache.j,
305-
u,
318+
isnothing(f1.Wfact) || update!(
319+
newtons_method,
320+
newtons_method_cache,
321+
NewTimeStep(t),
322+
FirstImplicitErrorJacobian(f1.Wfact, u, p, t, dt, _cache.γ),
306323
);
307324
)
308325

@@ -436,12 +453,15 @@ function not_generated_cache(
436453
filter(i -> save_tendency(i, a), i_range(a)),
437454
)
438455

456+
γs = unique(filter(!iszero, diag(as[2])))
457+
γ = length(γs) == 1 ? γs[1] : nothing
439458
u = prob.u0
440459
Uis = map(
441460
i -> Symbol(:U, i) => similar(u),
442461
filter(i -> !(i in u_alias_is(as[1], as[2])), i_range(as[1])[1:end - 1])
443462
)
444463
_cache = NamedTuple((
464+
:γ => γ,
445465
:U_temp => similar(u),
446466
Uis...,
447467
f_cache(:exp, as[1], typeof(prob.f.f2))...,
@@ -481,13 +501,11 @@ function not_generated_step_u!(integrator, cache::IMEXARKCache{as, cs}) where {a
481501
f_types = (typeof(f2), typeof(f1))
482502
(; u_alias_is_, first_i_s, new_js_s, js_to_save_s, has_implicit_step_s, save_tendency_s, old_js_s) = _cache
483503

484-
isnothing(f1.Wfact) || run!(
485-
newtons_method.update_j,
486-
newtons_method_cache.update_j_cache,
487-
NewStep(),
488-
ImplicitErrorJacobian(f1.Wfact, p, t, dt * FT(as[1][end, end])),
489-
newtons_method_cache.j,
490-
u,
504+
isnothing(f1.Wfact) || update!(
505+
newtons_method,
506+
newtons_method_cache,
507+
NewTimeStep(t),
508+
FirstImplicitErrorJacobian(f1.Wfact, u, p, t, dt, _cache.γ),
491509
)
492510

493511
function Δu_broadcast(i, j, χ, a, f_type, first_i_)

src/solvers/newtons_method.jl

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -325,9 +325,9 @@ end
325325

326326
"""
327327
KrylovMethod(;
328+
type = Val(Krylov.GmresSolver),
328329
jacobian_free_jvp = nothing,
329330
forcing_term = ConstantForcing(0),
330-
type = Val(Krylov.GmresSolver),
331331
args = (20,),
332332
kwargs = (;),
333333
solve_kwargs = (;),
@@ -347,13 +347,14 @@ where `x_prototype` is `similar` to `x` (and also to `Δx` and `f`).
347347
348348
This is primarily a wrapper for a `Krylov.KrylovSolver` from `Krylov.jl`. In
349349
`allocate_cache`, the solver is constructed with
350-
`solver = type(l, l, args..., Krylov.ktypeof(x_prototype); kwargs...)` (note
351-
that `type` must be passed through in a `Val` struct), where
350+
`solver = type(l, l, args..., Krylov.ktypeof(x_prototype); kwargs...)`, where
352351
`l = length(x_prototype)` and `Krylov.ktypeof(x_prototype)` is a subtype of
353352
`DenseVector` that can be used to store `x_prototype`. By default, the solver
354353
is a `Krylov.GmresSolver` with a Krylov subspace size of 20 (the default Krylov
355354
subspace size for this solver in `Krylov.jl`). In `run!`, the solver is run with
356355
`Krylov.solve!(solver, opj, f; M, ldiv, atol, rtol, verbose, solve_kwargs...)`.
356+
The solver's type can be changed by specifying a different value for `type`,
357+
though this value has to be wrapped in a `Val` to avoid runtime compilation.
357358
358359
In the call to `Krylov.solve!`, `opj` is a `LinearOperator` that represents
359360
`j(x[n])`, which the solver uses by evaluating `mul!(jΔx, opj, Δx)`. If a
@@ -388,7 +389,7 @@ each iteration of the Krylov method. If a debugger is specified, it is run
388389
before the call to `Krylov.solve!`.
389390
"""
390391
Base.@kwdef struct KrylovMethod{
391-
T <: Val,
392+
T <: Val{<:Krylov.KrylovSolver},
392393
J <: Union{Nothing, JacobianFreeJVP},
393394
F <: ForcingTerm,
394395
A <: Tuple,
@@ -412,7 +413,6 @@ solver_type(::KrylovMethod{Val{T}}) where {T} = T
412413
function allocate_cache(alg::KrylovMethod, x_prototype)
413414
(; jacobian_free_jvp, forcing_term, args, kwargs, debugger) = alg
414415
type = solver_type(alg)
415-
@assert type isa Type{<:Krylov.KrylovSolver}
416416
l = length(x_prototype)
417417
return (;
418418
jacobian_free_jvp_cache = isnothing(jacobian_free_jvp) ? nothing :
@@ -466,7 +466,7 @@ end
466466
"""
467467
NewtonsMethod(;
468468
max_iters = 1,
469-
update_j = UpdateEvery(NewNewtonIteration()),
469+
update_j = UpdateEvery(NewNewtonIteration),
470470
krylov_method = nothing,
471471
convergence_checker = nothing,
472472
verbose = false,
@@ -512,11 +512,23 @@ for its preconditioners, so, since the value computed with `j!` is used as a
512512
preconditioner in Krylov methods with a Jacobian-free JVP, using such a Krylov
513513
method requires specifying a `j_prototype` that can be passed to `ldiv!`.
514514
515-
If `j(x)` changes sufficiently slowly, `update_j` can be changed from
516-
`UpdateEvery(NewNewtonIteration())` to some other `UpdateSignalHandler` in order
517-
to make the approximation `j(x[n]) ≈ j(x₀)`, where `x₀` is a previous value of
518-
`x[n]` (this could even be a value from a previous `run!` of Newton's method).
519-
When Newton's method uses this approximation, it is called the "chord method".
515+
If `j(x)` changes sufficiently slowly, `update_j` may be changed from
516+
`UpdateEvery(NewNewtonIteration)` to some other `UpdateSignalHandler` that
517+
gets triggered less frequently, such as `UpdateEvery(NewNewtonSolve)`. This
518+
can be used to make the approximation `j(x[n]) ≈ j(x₀)`, where `x₀` is a
519+
previous value of `x[n]` (possibly even a value from a previous `run!` of
520+
Newton's method). When Newton's method uses such an approximation, it is called
521+
the "chord method".
522+
523+
In addition, `update_j` can be set to an `UpdateSignalHandler` that gets
524+
triggered by signals that originate outside of Newton's method, such as
525+
`UpdateEvery(NewTimeStep)`. It is possible to send any signal for updating `j`
526+
to Newton's method while it is not running by calling
527+
`update!(::NewtonsMethod, cache, ::UpdateSignal, j!)`, where in this case
528+
`j!(j)` is a function that sets `j` in-place without any dependence on `x`
529+
(since `x` is not necessarily defined while Newton's method is not running, this
530+
version of `j!` does not take `x` as an argument). This can be used to make the
531+
approximation `j(x[n]) ≈ j₀`, where `j₀` can have an arbitrary value.
520532
521533
If a convergence checker is provided, it gets used to determine whether to stop
522534
iterating on iteration `n` based on the value `x[n]` and its error `Δx[n]`;
@@ -534,7 +546,7 @@ Base.@kwdef struct NewtonsMethod{
534546
C <: Union{Nothing, ConvergenceChecker},
535547
}
536548
max_iters::Int = 1
537-
update_j::U = UpdateEvery(NewNewtonIteration())
549+
update_j::U = UpdateEvery(NewNewtonIteration)
538550
krylov_method::K = nothing
539551
convergence_checker::C = nothing
540552
verbose::Bool = false
@@ -547,7 +559,7 @@ function allocate_cache(alg::NewtonsMethod, x_prototype, j_prototype = nothing)
547559
(isnothing(krylov_method) || isnothing(krylov_method.jacobian_free_jvp))
548560
)
549561
return (;
550-
update_j_cache = allocate_cache(update_j),
562+
update_j_cache = allocate_cache(update_j, eltype(x_prototype)),
551563
krylov_method_cache = isnothing(krylov_method) ? nothing :
552564
allocate_cache(krylov_method, x_prototype),
553565
convergence_checker_cache = isnothing(convergence_checker) ? nothing :
@@ -596,3 +608,9 @@ function run!(alg::NewtonsMethod, cache, x, f!, j! = nothing)
596608
end
597609
end
598610
end
611+
612+
function update!(alg::NewtonsMethod, cache, signal::UpdateSignal, j!)
613+
(; update_j) = alg
614+
(; update_j_cache, j) = cache
615+
isnothing(j) || run!(update_j, update_j_cache, signal, j!, j)
616+
end
Lines changed: 61 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
export UpdateSignal, NewStep, NewNewtonSolve, NewNewtonIteration
2-
export UpdateSignalHandler, UpdateEvery, UpdateEveryN
1+
export UpdateSignal, NewTimeStep, NewNewtonSolve, NewNewtonIteration
2+
export UpdateSignalHandler, UpdateEvery, UpdateEveryN, UpdateEveryDt
33

44
"""
55
UpdateSignal
@@ -10,11 +10,13 @@ operation is performed.
1010
abstract type UpdateSignal end
1111

1212
"""
13-
NewStep()
13+
NewTimeStep(t)
1414
15-
The signal for a new time step.
15+
The signal for a new time step at time `t`.
1616
"""
17-
struct NewStep <: UpdateSignal end
17+
struct NewTimeStep{T} <: UpdateSignal
18+
t::T
19+
end
1820

1921
"""
2022
NewNewtonSolve()
@@ -37,62 +39,78 @@ struct NewNewtonIteration <: UpdateSignal end
3739
Updates a value upon receiving an appropriate `UpdateSignal`. This is done by
3840
calling `run!(::UpdateSignalHandler, cache, ::UpdateSignal, f!, args...)`, where
3941
`f!` is a function such that `f!(args...)` modifies the desired value in-place.
40-
The `cache` can be obtained with `allocate_cache(::UpdateSignalHandler)`.
42+
The `cache` can be obtained with `allocate_cache(::UpdateSignalHandler, FT)`,
43+
where `FT` is the floating-point type of the integrator.
4144
"""
4245
abstract type UpdateSignalHandler end
4346

4447
"""
45-
UpdateEvery(update_signal)
48+
UpdateEvery(update_signal_type)
4649
47-
An `UpdateSignalHandler` that executes the update every time it is `run!` with
48-
`update_signal`.
50+
An `UpdateSignalHandler` that performs the update whenever it is `run!` with an
51+
`UpdateSignal` of type `update_signal_type`.
4952
"""
50-
struct UpdateEvery{U <: UpdateSignal} <: UpdateSignalHandler
51-
update_signal::U
52-
end
53+
struct UpdateEvery{U <: UpdateSignal} <: UpdateSignalHandler end
54+
UpdateEvery(::Type{U}) where {U} = UpdateEvery{U}()
5355

54-
allocate_cache(::UpdateSignalHandler) = (;)
55-
56-
function run!(alg::UpdateEvery{U}, cache, ::U, f!, args...) where {
57-
U <: UpdateSignal,
58-
}
59-
f!(args...)
60-
return true
61-
end
56+
run!(alg::UpdateEvery{U}, cache, ::U, f!, args...) where {U} = f!(args...)
6257

6358
"""
64-
UpdateEveryN(update_signal, n, reset_n_signal = nothing)
59+
UpdateEveryN(n, update_signal_type, reset_signal_type = Nothing)
6560
66-
An `UpdateSignalHandler` that executes the update every `n`-th time it is `run!`
67-
with `update_signal`. If `reset_n_signal` is specified, then the value of `n` is
68-
reset to 0 every time the signal handler is `run!` with `reset_n_signal`.
61+
An `UpdateSignalHandler` that performs the update every `n`-th time it is `run!`
62+
with an `UpdateSignal` of type `update_signal_type`. If `reset_signal_type` is
63+
specified, then the counter (which gets incremented from 0 to `n` and then gets
64+
reset to 0 when it is time to perform another update) is reset to 0 whenever the
65+
signal handler is `run!` with an `UpdateSignal` of type `reset_signal_type`.
6966
"""
7067
struct UpdateEveryN{U <: UpdateSignal, R <: Union{Nothing, UpdateSignal}} <:
7168
UpdateSignalHandler
72-
update_signal::U
7369
n::Int
74-
reset_n_signal::R
7570
end
76-
UpdateEveryN(update_signal, n, reset_n_signal = nothing) =
77-
UpdateEveryN(update_signal, n, reset_n_signal)
71+
UpdateEveryN(n, ::Type{U}, ::Type{R} = Nothing) where {U, R} =
72+
UpdateEveryN{U, R}(n)
7873

79-
allocate_cache(::UpdateEveryN) = (; n = Ref(0))
74+
allocate_cache(::UpdateEveryN, _) = (; counter = Ref(0))
8075

81-
function run!(alg::UpdateEveryN{U}, cache, ::U, f!, args...) where {
82-
U <: UpdateSignal,
83-
}
84-
cache.n[] += 1
85-
if cache.n[] == alg.n
76+
function run!(alg::UpdateEveryN{U}, cache, ::U, f!, args...) where {U}
77+
(; n) = alg
78+
(; counter) = cache
79+
if counter[] == 0
8680
f!(args...)
87-
cache.n[] = 0
88-
return true
8981
end
90-
return false
82+
counter[] += 1
83+
if counter[] == n
84+
counter[] = 0
85+
end
86+
end
87+
function run!(alg::UpdateEveryN{<:Any, R}, cache, ::R, f!, args...) where {R}
88+
(; counter) = cache
89+
counter[] = 0
9190
end
92-
function run!(alg::UpdateEveryN{U, R}, cache, ::R, f!, args...) where {
93-
U,
94-
R <: UpdateSignal,
95-
}
96-
cache.n[] = 0
97-
return false
91+
92+
"""
93+
UpdateEveryDt(dt)
94+
95+
An `UpdateSignalHandler` that performs the update whenever it is `run!` with an
96+
`UpdateSignal` of type `NewTimeStep` and the difference between the current time
97+
and the previous update time is no less than `dt`.
98+
"""
99+
struct UpdateEveryDt{T} <: UpdateSignalHandler
100+
dt::T
101+
end
102+
103+
# TODO: This assumes that typeof(t) == FT, which might not always be correct.
104+
allocate_cache(alg::UpdateEveryDt, ::Type{FT}) where {FT} =
105+
(; is_first_t = Ref(true), prev_update_t = Ref{FT}())
106+
107+
function run!(alg::UpdateEveryDt, cache, signal::NewTimeStep, f!, args...)
108+
(; dt) = alg
109+
(; is_first_t, prev_update_t) = cache
110+
(; t) = signal
111+
if is_first_t[] || abs(t - prev_update_t[]) >= dt
112+
f!(args...)
113+
is_first_t[] = false
114+
prev_update_t[] = t
115+
end
98116
end

0 commit comments

Comments
 (0)