Skip to content

Commit 8c36569

Browse files
Remove generic fallback, add abstract type methods
1 parent 048ed07 commit 8c36569

File tree

6 files changed

+19
-37
lines changed

6 files changed

+19
-37
lines changed

docs/src/algorithms.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ CurrentModule = ClimaTimeSteppers
77
## Interface and OrdinaryDiffEq compatibility
88

99
```@docs
10-
allocate_cache
11-
run!
1210
ForwardEulerODEFunction
1311
```
1412

src/ClimaTimeSteppers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ array_device(::CuArray) = CUDADevice()
5656
realview(x::Union{Array, SArray, MArray}) = x
5757
realview(x::CuArray) = x
5858

59+
export allocate_cache, run!
5960

6061
import DiffEqBase, SciMLBase, LinearAlgebra, DiffEqCallbacks, Krylov
6162

6263
include("sparse_containers.jl")
6364
include("functions.jl")
6465
include("operators.jl")
65-
include("algorithms.jl")
6666

6767
abstract type DistributedODEAlgorithm <: DiffEqBase.AbstractODEAlgorithm end
6868

src/algorithms.jl

Lines changed: 0 additions & 20 deletions
This file was deleted.

src/solvers/update_signal_handler.jl

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,19 @@ operation is performed.
99
"""
1010
abstract type UpdateSignal end
1111

12+
"""
13+
UpdateSignalHandler
14+
15+
Updates a value upon receiving an appropriate `UpdateSignal`. This is done by
16+
calling `run!(::UpdateSignalHandler, cache, ::UpdateSignal, f!, args...)`, where
17+
`f!` is function such that `f!(args...)` modifies the desired value in-place.
18+
The `cache` can be obtained with `allocate_cache(::UpdateSignalHandler, FT)`,
19+
where `FT` is the floating-point type of the integrator.
20+
"""
21+
abstract type UpdateSignalHandler end
22+
23+
run!(::UpdateSignalHandler, cache, ::UpdateSignal, f!, args...) = nothing
24+
1225
"""
1326
NewTimeStep(t)
1427
@@ -33,17 +46,6 @@ The signal for a new iteration of Newton's method.
3346
"""
3447
struct NewNewtonIteration <: UpdateSignal end
3548

36-
"""
37-
UpdateSignalHandler
38-
39-
Updates a value upon receiving an appropriate `UpdateSignal`. This is done by
40-
calling `run!(::UpdateSignalHandler, cache, ::UpdateSignal, f!, args...)`, where
41-
`f!` is function such that `f!(args...)` modifies the desired value in-place.
42-
The `cache` can be obtained with `allocate_cache(::UpdateSignalHandler, FT)`,
43-
where `FT` is the floating-point type of the integrator.
44-
"""
45-
abstract type UpdateSignalHandler end
46-
4749
"""
4850
UpdateEvery(update_signal_type)
4951
@@ -53,7 +55,9 @@ An `UpdateSignalHandler` that performs the update whenever it is `run!` with an
5355
struct UpdateEvery{U <: UpdateSignal} <: UpdateSignalHandler end
5456
UpdateEvery(::Type{U}) where {U} = UpdateEvery{U}()
5557

56-
run!(alg::UpdateEvery{U}, cache, ::U, f!, args...) where {U} = f!(args...)
58+
allocate_cache(::UpdateEvery, _) = nothing
59+
60+
run!(alg::UpdateEvery{U}, cache, ::U, f!, args...) where {U <: UpdateSignal} = f!(args...)
5761

5862
"""
5963
UpdateEveryN(n, update_signal_type, reset_signal_type = Nothing)

test/test_convergence_checker.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import ClimaTimeSteppers as CTS
1414
last_iters1 = (11, 9, 10, 12)
1515
last_iters2 = (12, 10, 10, 13)
1616
function test_func(checker, last_iter1, last_iter2)
17-
cache = allocate_cache(checker, val_func(0))
17+
cache = CTS.allocate_cache(checker, val_func(0))
1818
for (err_func, last_iter) in ((err_func1, last_iter1), (err_func2, last_iter2))
1919
for iter in 0:(last_iter - 1)
2020
CTS.check_convergence!(checker, cache, val_func(iter), err_func(iter), iter) && return false

test/test_newtons_method.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ end
6060
for (alg, use_j) in ((alg1, true), (alg2, true), (alg3, false))
6161
x = copy(x_init)
6262
j_prototype = similar(x, length(x), length(x))
63-
cache = allocate_cache(alg, x, use_j ? j_prototype : nothing)
63+
cache = CTS.allocate_cache(alg, x, use_j ? j_prototype : nothing)
6464
CTS.solve_newton!(alg, cache, x, f!, use_j ? j! : nothing)
6565
@test norm(x .- x_exact) / norm(x_exact) < rtol
6666
end

0 commit comments

Comments
 (0)