Skip to content

Commit 48bd3af

Browse files
Re-design run for signal handlers
1 parent 912e2fd commit 48bd3af

File tree

3 files changed

+33
-22
lines changed

3 files changed

+33
-22
lines changed

src/ClimaTimeSteppers.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ array_device(::CuArray) = CUDADevice()
5656
realview(x::Union{Array, SArray, MArray}) = x
5757
realview(x::CuArray) = x
5858

59-
export allocate_cache, run!
60-
6159
import DiffEqBase, SciMLBase, LinearAlgebra, DiffEqCallbacks, Krylov
6260

6361
include("sparse_containers.jl")

src/solvers/newtons_method.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,9 @@ function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing)
564564
(; max_iters, update_j, krylov_method, convergence_checker, verbose) = alg
565565
(; update_j_cache, krylov_method_cache, convergence_checker_cache) = cache
566566
(; Δx, f, j) = cache
567-
isnothing(j) || run!(update_j, update_j_cache, NewNewtonSolve(), j!, j, x)
567+
if (!isnothing(j)) && needs_update!(update_j, update_j_cache, NewNewtonSolve())
568+
j!(j, x)
569+
end
568570
for n in 0:max_iters
569571
# Update x[n] with Δx[n - 1], and exit the loop if Δx[n] is not needed.
570572
n > 0 && (x .-= Δx)
@@ -574,7 +576,9 @@ function solve_newton!(alg::NewtonsMethod, cache, x, f!, j! = nothing)
574576
end
575577

576578
# Compute Δx[n].
577-
isnothing(j) || run!(update_j, update_j_cache, NewNewtonIteration(), j!, j, x)
579+
if (!isnothing(j)) && needs_update!(update_j, update_j_cache, NewNewtonIteration())
580+
j!(j, x)
581+
end
578582
f!(f, x)
579583
if isnothing(krylov_method)
580584
if j isa DenseMatrix
@@ -598,5 +602,7 @@ end
598602
function update!(alg::NewtonsMethod, cache, signal::UpdateSignal, j!)
599603
(; update_j) = alg
600604
(; update_j_cache, j) = cache
601-
isnothing(j) || run!(update_j, update_j_cache, signal, j!, j)
605+
if (!isnothing(j)) && needs_update!(update_j, update_j_cache, signal)
606+
j!(j)
607+
end
602608
end

src/solvers/update_signal_handler.jl

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ abstract type UpdateSignal end
1212
"""
1313
UpdateSignalHandler
1414
15-
Updates a value upon receiving an appropriate `UpdateSignal`. This is done by
16-
calling `run!(::UpdateSignalHandler, cache, ::UpdateSignal, f!, args...)`, where
17-
`f!` is function such that `f!(args...)` modifies the desired value in-place.
15+
A boolean indicating if updates a value upon receiving an appropriate
16+
`UpdateSignal`. This is done by calling
17+
`needs_update!(::UpdateSignalHandler, cache, ::UpdateSignal)`.
18+
1819
The `cache` can be obtained with `allocate_cache(::UpdateSignalHandler, FT)`,
1920
where `FT` is the floating-point type of the integrator.
2021
"""
2122
abstract type UpdateSignalHandler end
2223

23-
run!(::UpdateSignalHandler, cache, ::UpdateSignal, f!, args...) = nothing
24+
needs_update!(::UpdateSignalHandler, cache, ::UpdateSignal) = false
2425

2526
"""
2627
NewTimeStep(t)
@@ -34,7 +35,7 @@ end
3435
"""
3536
NewNewtonSolve()
3637
37-
The signal for a new `run!` of Newton's method, which occurs on every implicit
38+
The signal for a new `needs_update!` of Newton's method, which occurs on every implicit
3839
Runge-Kutta stage of the integrator.
3940
"""
4041
struct NewNewtonSolve <: UpdateSignal end
@@ -49,24 +50,24 @@ struct NewNewtonIteration <: UpdateSignal end
4950
"""
5051
UpdateEvery(update_signal_type)
5152
52-
An `UpdateSignalHandler` that performs the update whenever it is `run!` with an
53+
An `UpdateSignalHandler` that performs the update whenever it is `needs_update!` with an
5354
`UpdateSignal` of type `update_signal_type`.
5455
"""
5556
struct UpdateEvery{U <: UpdateSignal} <: UpdateSignalHandler end
5657
UpdateEvery(::Type{U}) where {U} = UpdateEvery{U}()
5758

5859
allocate_cache(::UpdateEvery, _) = nothing
5960

60-
run!(alg::UpdateEvery{U}, cache, ::U, f!, args...) where {U <: UpdateSignal} = f!(args...)
61+
needs_update!(alg::UpdateEvery{U}, cache, ::U) where {U <: UpdateSignal} = true
6162

6263
"""
6364
UpdateEveryN(n, update_signal_type, reset_signal_type = Nothing)
6465
65-
An `UpdateSignalHandler` that performs the update every `n`-th time it is `run!`
66+
An `UpdateSignalHandler` that performs the update every `n`-th time it is `needs_update!`
6667
with an `UpdateSignal` of type `update_signal_type`. If `reset_signal_type` is
6768
specified, then the counter (which gets incremented from 0 to `n` and then gets
6869
reset to 0 when it is time to perform another update) is reset to 0 whenever the
69-
signal handler is `run!` with an `UpdateSignal` of type `reset_signal_type`.
70+
signal handler is `needs_update!` with an `UpdateSignal` of type `reset_signal_type`.
7071
"""
7172
struct UpdateEveryN{U <: UpdateSignal, R <: Union{Nothing, UpdateSignal}} <: UpdateSignalHandler
7273
n::Int
@@ -75,26 +76,30 @@ UpdateEveryN(n, ::Type{U}, ::Type{R} = Nothing) where {U, R} = UpdateEveryN{U, R
7576

7677
allocate_cache(::UpdateEveryN, _) = (; counter = Ref(0))
7778

78-
function run!(alg::UpdateEveryN{U}, cache, ::U, f!, args...) where {U}
79+
function needs_update!(alg::UpdateEveryN{U}, cache, ::U) where {U <: UpdateSignal}
7980
(; n) = alg
8081
(; counter) = cache
81-
if counter[] == 0
82-
f!(args...)
83-
end
82+
result = counter[] == 0
8483
counter[] += 1
8584
if counter[] == n
8685
counter[] = 0
8786
end
87+
return result
8888
end
89-
function run!(alg::UpdateEveryN{<:Any, R}, cache, ::R, f!, args...) where {R}
89+
function needs_update!(alg::UpdateEveryN{U, R}, cache, ::R) where {U, R <: UpdateSignal}
9090
(; counter) = cache
9191
counter[] = 0
92+
return false
9293
end
9394

95+
# Account for method ambiguitiy:
96+
needs_update!(::UpdateEveryN{U, U}, cache, ::U) where {U <: UpdateSignal} =
97+
error("Reset and update signal types cannot be the same.")
98+
9499
"""
95100
UpdateEveryDt(dt)
96101
97-
An `UpdateSignalHandler` that performs the update whenever it is `run!` with an
102+
An `UpdateSignalHandler` that performs the update whenever it is `needs_update!` with an
98103
`UpdateSignal` of type `NewTimeStep` and the difference between the current time
99104
and the previous update time is no less than `dt`.
100105
"""
@@ -105,13 +110,15 @@ end
105110
# TODO: This assumes that typeof(t) == FT, which might not always be correct.
106111
allocate_cache(alg::UpdateEveryDt, ::Type{FT}) where {FT} = (; is_first_t = Ref(true), prev_update_t = Ref{FT}())
107112

108-
function run!(alg::UpdateEveryDt, cache, signal::NewTimeStep, f!, args...)
113+
function needs_update!(alg::UpdateEveryDt, cache, signal::NewTimeStep)
109114
(; dt) = alg
110115
(; is_first_t, prev_update_t) = cache
111116
(; t) = signal
117+
result = false
112118
if is_first_t[] || abs(t - prev_update_t[]) >= dt
113-
f!(args...)
119+
result = true
114120
is_first_t[] = false
115121
prev_update_t[] = t
116122
end
123+
return result
117124
end

0 commit comments

Comments
 (0)