Skip to content

Commit ae9a03a

Browse files
authored
Merge pull request #360 from CliMA/kp/itime
Add support for ITime in IMEXAlgorithms and SSPKnoth
2 parents 207704b + 0794890 commit ae9a03a

File tree

8 files changed

+38
-13
lines changed

8 files changed

+38
-13
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ v0.8.2
77
- ![][badge-💥breaking] If saveat is a number, then it does not automatically expand to `tspan[1]:saveat:tspan[2]`. To fix this, update
88
`saveat`, which is a keyword in the integrator, to be an array. For example, if `saveat` is a scalar, replace it with
99
`[tspan[1]:saveat:tspan[2]..., tspan[2]]` to achieve the same behavior as before.
10+
- IMEXAlgorithms and SSPKnoth are compatible with ITime. See ClimaUtilities for more information about ITime.
1011

1112
v0.7.18
1213
-------

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ClimaTimeSteppers"
22
uuid = "595c0a79-7f3d-439a-bc5a-b232dc3bde79"
33
authors = ["Climate Modeling Alliance"]
4-
version = "0.8.1"
4+
version = "0.8.2"
55

66
[deps]
77
ClimaComms = "3a4d1b5c-c61d-41fd-a00a-5873ba7a1b0d"

src/ClimaTimeSteppers.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ include("solvers/rosenbrock.jl")
128128

129129
include("Callbacks.jl")
130130

131+
include("arbitrary_number_types.jl")
131132

132133
benchmark_step(integrator, device) =
133134
@warn "Must load CUDA, BenchmarkTools, OrderedCollections, StatsBase, PrettyTables to trigger the ClimaTimeSteppersBenchmarkToolsExt extension"

src/arbitrary_number_types.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
"""
2+
SciMLBase.allows_arbitrary_number_types(alg::T)
3+
where {T <: ClimaTimeSteppers.RosenbrockAlgorithm}
4+
5+
Return `true`. Enable RosenbrockAlgorithms to run with `ClimaUtilities.ITime`.
6+
"""
7+
function SciMLBase.allows_arbitrary_number_types(alg::T) where {T <: ClimaTimeSteppers.RosenbrockAlgorithm}
8+
true
9+
end
10+
11+
"""
12+
SciMLBase.allows_arbitrary_number_types(alg::T)
13+
where {T <: ClimaTimeSteppers.IMEXAlgorithm}
14+
15+
Return `true`. Enable IMEXAlgorithms to run with `ClimaUtilities.ITime`.
16+
"""
17+
function SciMLBase.allows_arbitrary_number_types(alg::T) where {T <: ClimaTimeSteppers.IMEXAlgorithm}
18+
true
19+
end

src/integrators.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ end
6767

6868
# helper function for setting up min/max heaps for tstops and saveat
6969
function tstops_and_saveat_heaps(t0, tf, tstops, saveat = [])
70-
FT = typeof(tf)
70+
# We promote to a common type to ensure that t0 and tf have the same type
71+
FT = typeof(first(promote(t0, tf)))
7172
ordering = tf > t0 ? DataStructures.FasterForward : DataStructures.FasterReverse
7273

7374
# ensure that tstops includes tf and only has values ahead of t0
@@ -81,7 +82,7 @@ function tstops_and_saveat_heaps(t0, tf, tstops, saveat = [])
8182
return tstops, saveat
8283
end
8384

84-
compute_tdir(ts) = ts[1] > ts[end] ? sign(ts[end] - ts[1]) : eltype(ts)(1)
85+
compute_tdir(ts) = ts[1] > ts[end] ? sign(ts[end] - ts[1]) : oneunit(ts[1])
8586

8687
# called by DiffEqBase.init and DiffEqBase.solve
8788
function DiffEqBase.__init(
@@ -102,8 +103,10 @@ function DiffEqBase.__init(
102103
)
103104
(; u0, p) = prob
104105
t0, tf = prob.tspan
106+
t0, tf, dt = promote(t0, tf, dt)
105107

106-
dt > zero(dt) || error("dt must be positive")
108+
# We need zero(oneunit()) because there's no zerounit
109+
dt > zero(oneunit(dt)) || error("dt must be positive")
107110
_dt = dt
108111
dt = tf > t0 ? dt : -dt
109112

@@ -243,8 +246,9 @@ function __step!(integrator)
243246
# is taken from OrdinaryDiffEq.jl
244247
t_plus_dt = integrator.t + integrator.dt
245248
t_unit = oneunit(integrator.t)
246-
max_t_error = 100 * eps(float(integrator.t / t_unit)) * t_unit
247-
integrator.t = !isempty(tstops) && abs(first(tstops) - t_plus_dt) < max_t_error ? first(tstops) : t_plus_dt
249+
max_t_error = 100 * eps(float(integrator.t / t_unit)) * float(t_unit)
250+
integrator.t =
251+
!isempty(tstops) && abs(float(first(tstops)) - float(t_plus_dt)) < max_t_error ? first(tstops) : t_plus_dt
248252

249253
# apply callbacks
250254
discrete_callbacks = integrator.callback.discrete_callbacks

src/solvers/imex_ark.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ end
107107

108108
t_exp = t + dt * c_exp[i]
109109
t_imp = t + dt * c_imp[i]
110-
dtγ = dt * a_imp[i, i]
110+
dtγ = float(dt) * a_imp[i, i]
111111

112112
if has_T_lim(f) # Update based on limited tendencies from previous stages
113113
assign_fused_increment!(U, u, dt, a_exp, T_lim, Val(i))

src/solvers/rosenbrock.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}
127127

128128
# TODO: This is only valid when Γ[i, i] is constant, otherwise we have to
129129
# move this in the for loop
130-
@inbounds dtγ = dt * Γ[1, 1]
130+
@inbounds dtγ = float(dt) * Γ[1, 1]
131131

132132
if !isnothing(T_imp!)
133133
Wfact! = int.sol.prob.f.T_imp!.Wfact
@@ -175,14 +175,14 @@ function step_u!(int, cache::RosenbrockCache{Nstages}) where {Nstages}
175175
end
176176

177177
if !isnothing(tgrad!)
178-
fU .+= γi .* dt .* ∂Y∂t
178+
fU .+= γi .* float(dt) .* ∂Y∂t
179179
end
180180

181181
for j in 1:(i - 1)
182-
fU .+= (C[i, j] / dt) .* k[j]
182+
fU .+= (C[i, j] / float(dt)) .* k[j]
183183
end
184184

185-
fU .*= -dtγ
185+
fU .*= -float(dtγ)
186186

187187
if !isnothing(T_imp!)
188188
if W isa Matrix

src/utilities/fused_increment.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ In the edge case (coeffs are zero, `j` range is empty),
121121
this lowers to `nothing` (no-op)
122122
"""
123123
@inline function fused_increment!(u, dt, sc, tend, v)
124-
bc = fused_increment(u, dt, sc, tend, v)
124+
bc = fused_increment(u, float(dt), sc, tend, v)
125125
if bc isa Base.Broadcast.Broadcasted # Only material if not trivial assignment
126126
Base.Broadcast.materialize!(u, bc)
127127
end
@@ -142,7 +142,7 @@ this lowers to
142142
`@. U = u`
143143
"""
144144
@inline function assign_fused_increment!(U, u, dt, sc, tend, v)
145-
bc = fused_increment(u, dt, sc, tend, v)
145+
bc = fused_increment(u, float(dt), sc, tend, v)
146146
Base.Broadcast.materialize!(U, bc)
147147
return nothing
148148
end

0 commit comments

Comments
 (0)