@@ -12,15 +12,16 @@ abstract type UpdateSignal end
1212"""
1313 UpdateSignalHandler
1414
15- Updates a value upon receiving an appropriate `UpdateSignal`. This is done by
16- calling `run!(::UpdateSignalHandler, cache, ::UpdateSignal, f!, args...)`, where
17- `f!` is function such that `f!(args...)` modifies the desired value in-place.
15+ A boolean indicating if updates a value upon receiving an appropriate
16+ `UpdateSignal`. This is done by calling
17+ `needs_update!(::UpdateSignalHandler, cache, ::UpdateSignal)`.
18+
1819The `cache` can be obtained with `allocate_cache(::UpdateSignalHandler, FT)`,
1920where `FT` is the floating-point type of the integrator.
2021"""
2122abstract type UpdateSignalHandler end
2223
23- run ! (:: UpdateSignalHandler , cache, :: UpdateSignal , f!, args ... ) = nothing
24+ needs_update ! (:: UpdateSignalHandler , cache, :: UpdateSignal ) = false
2425
2526"""
2627 NewTimeStep(t)
3435"""
3536 NewNewtonSolve()
3637
37- The signal for a new `run !` of Newton's method, which occurs on every implicit
38+ The signal for a new `needs_update !` of Newton's method, which occurs on every implicit
3839Runge-Kutta stage of the integrator.
3940"""
4041struct NewNewtonSolve <: UpdateSignal end
@@ -49,24 +50,24 @@ struct NewNewtonIteration <: UpdateSignal end
4950"""
5051 UpdateEvery(update_signal_type)
5152
52- An `UpdateSignalHandler` that performs the update whenever it is `run !` with an
53+ An `UpdateSignalHandler` that performs the update whenever it is `needs_update !` with an
5354`UpdateSignal` of type `update_signal_type`.
5455"""
5556struct UpdateEvery{U <: UpdateSignal } <: UpdateSignalHandler end
5657UpdateEvery (:: Type{U} ) where {U} = UpdateEvery {U} ()
5758
5859allocate_cache (:: UpdateEvery , _) = nothing
5960
60- run ! (alg:: UpdateEvery{U} , cache, :: U , f!, args ... ) where {U <: UpdateSignal } = f! (args ... )
61+ needs_update ! (alg:: UpdateEvery{U} , cache, :: U ) where {U <: UpdateSignal } = true
6162
6263"""
6364 UpdateEveryN(n, update_signal_type, reset_signal_type = Nothing)
6465
65- An `UpdateSignalHandler` that performs the update every `n`-th time it is `run !`
66+ An `UpdateSignalHandler` that performs the update every `n`-th time it is `needs_update !`
6667with an `UpdateSignal` of type `update_signal_type`. If `reset_signal_type` is
6768specified, then the counter (which gets incremented from 0 to `n` and then gets
6869reset to 0 when it is time to perform another update) is reset to 0 whenever the
69- signal handler is `run !` with an `UpdateSignal` of type `reset_signal_type`.
70+ signal handler is `needs_update !` with an `UpdateSignal` of type `reset_signal_type`.
7071"""
7172struct UpdateEveryN{U <: UpdateSignal , R <: Union{Nothing, UpdateSignal} } <: UpdateSignalHandler
7273 n:: Int
@@ -75,26 +76,30 @@ UpdateEveryN(n, ::Type{U}, ::Type{R} = Nothing) where {U, R} = UpdateEveryN{U, R
7576
7677allocate_cache (:: UpdateEveryN , _) = (; counter = Ref (0 ))
7778
78- function run ! (alg:: UpdateEveryN{U} , cache, :: U , f!, args ... ) where {U}
79+ function needs_update ! (alg:: UpdateEveryN{U} , cache, :: U ) where {U <: UpdateSignal }
7980 (; n) = alg
8081 (; counter) = cache
81- if counter[] == 0
82- f! (args... )
83- end
82+ result = counter[] == 0
8483 counter[] += 1
8584 if counter[] == n
8685 counter[] = 0
8786 end
87+ return result
8888end
89- function run ! (alg:: UpdateEveryN{<:Any , R} , cache, :: R , f!, args ... ) where {R }
89+ function needs_update ! (alg:: UpdateEveryN{U , R} , cache, :: R ) where {U, R <: UpdateSignal }
9090 (; counter) = cache
9191 counter[] = 0
92+ return false
9293end
9394
95+ # Account for method ambiguitiy:
96+ needs_update! (:: UpdateEveryN{U, U} , cache, :: U ) where {U <: UpdateSignal } =
97+ error (" Reset and update signal types cannot be the same." )
98+
9499"""
95100 UpdateEveryDt(dt)
96101
97- An `UpdateSignalHandler` that performs the update whenever it is `run !` with an
102+ An `UpdateSignalHandler` that performs the update whenever it is `needs_update !` with an
98103`UpdateSignal` of type `NewTimeStep` and the difference between the current time
99104and the previous update time is no less than `dt`.
100105"""
@@ -105,13 +110,15 @@ end
105110# TODO : This assumes that typeof(t) == FT, which might not always be correct.
106111allocate_cache (alg:: UpdateEveryDt , :: Type{FT} ) where {FT} = (; is_first_t = Ref (true ), prev_update_t = Ref {FT} ())
107112
108- function run ! (alg:: UpdateEveryDt , cache, signal:: NewTimeStep , f!, args ... )
113+ function needs_update ! (alg:: UpdateEveryDt , cache, signal:: NewTimeStep )
109114 (; dt) = alg
110115 (; is_first_t, prev_update_t) = cache
111116 (; t) = signal
117+ result = false
112118 if is_first_t[] || abs (t - prev_update_t[]) >= dt
113- f! (args ... )
119+ result = true
114120 is_first_t[] = false
115121 prev_update_t[] = t
116122 end
123+ return result
117124end
0 commit comments