Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions src/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,32 @@ export ClimaODEFunction, ForwardEulerODEFunction

abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end

Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ...
T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ...
T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ...
lim!::L = (u, p, t, u_ref) -> nothing
dss!::D = (u, p, t) -> nothing
post_explicit!::PE = (u, p, t) -> nothing
post_implicit!::PI = (u, p, t) -> nothing
struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction
T_lim!::TL
T_exp!::TE
T_imp!::TI
lim!::L
dss!::D
post_explicit!::PE
post_implicit!::PI
end
function ClimaODEFunction(;
T_lim! = nothing,
T_exp! = nothing,
T_imp! = nothing,
lim! = nothing,
dss! = nothing,
post_explicit! = nothing,
post_implicit! = nothing,
)
isnothing(T_lim!) && (T_lim! = (uₜ, u, p, t) -> nothing)
isnothing(T_exp!) && (T_exp! = (uₜ, u, p, t) -> nothing)
T_imp! = nothing
isnothing(lim!) && (lim! = (u, p, t, u_ref) -> nothing)
isnothing(dss!) && (dss! = (u, p, t) -> nothing)
isnothing(post_explicit!) && (post_explicit! = (u, p, t) -> nothing)
isnothing(post_implicit!) && (post_implicit! = (u, p, t) -> nothing)
return ClimaODEFunction(T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!)
end

# Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work).
Expand Down
52 changes: 22 additions & 30 deletions src/solvers/imex_ark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,17 @@ function step_u!(integrator, cache::IMEXARKCache)

@. U = u

if !isnothing(T_lim!) # Update based on limited tendencies from previous stages
for j in 1:(i - 1)
iszero(a_exp[i, j]) && continue
@. U += dt * a_exp[i, j] * T_lim[j]
end
lim!(U, p, t_exp, u)
# Update based on limited tendencies from previous stages
for j in 1:(i - 1)
iszero(a_exp[i, j]) && continue
@. U += dt * a_exp[i, j] * T_lim[j]
end
lim!(U, p, t_exp, u)

if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages
for j in 1:(i - 1)
iszero(a_exp[i, j]) && continue
@. U += dt * a_exp[i, j] * T_exp[j]
end
# Update based on explicit tendencies from previous stages
for j in 1:(i - 1)
iszero(a_exp[i, j]) && continue
@. U += dt * a_exp[i, j] * T_exp[j]
end

if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages
Expand Down Expand Up @@ -147,32 +145,26 @@ function step_u!(integrator, cache::IMEXARKCache)
end

if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i])
if !isnothing(T_lim!)
T_lim!(T_lim[i], U, p, t_exp)
end
if !isnothing(T_exp!)
T_exp!(T_exp[i], U, p, t_exp)
end
T_lim!(T_lim[i], U, p, t_exp)
T_exp!(T_exp[i], U, p, t_exp)
end
end

t_final = t + dt

if !isnothing(T_lim!) # Update based on limited tendencies from previous stages
@. temp = u
for j in 1:s
iszero(b_exp[j]) && continue
@. temp += dt * b_exp[j] * T_lim[j]
end
lim!(temp, p, t_final, u)
@. u = temp
# Update based on limited tendencies from previous stages
@. temp = u
for j in 1:s
iszero(b_exp[j]) && continue
@. temp += dt * b_exp[j] * T_lim[j]
end
lim!(temp, p, t_final, u)
@. u = temp

if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages
for j in 1:s
iszero(b_exp[j]) && continue
@. u += dt * b_exp[j] * T_exp[j]
end
# Update based on explicit tendencies from previous stages
for j in 1:s
iszero(b_exp[j]) && continue
@. u += dt * b_exp[j] * T_exp[j]
end

if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages
Expand Down
46 changes: 17 additions & 29 deletions src/solvers/imex_ssprk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{SSP}
s = length(b_exp)
inds = ntuple(i -> i, s)
inds_T_imp = filter(i -> !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]), inds)
U = similar(u0)
U_exp = similar(u0)
T_lim = similar(u0)
T_exp = similar(u0)
U_lim = similar(u0)
T_imp = SparseContainer(map(i -> similar(u0), collect(1:length(inds_T_imp))), inds_T_imp)
temp = similar(u0)
U = zero(u0)
U_exp = zero(u0)
T_lim = zero(u0)
T_exp = zero(u0)
U_lim = zero(u0)
T_imp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_imp))), inds_T_imp)
temp = zero(u0)
â_exp = vcat(a_exp, b_exp')
β = diag(â_exp, -1)
for i in 1:length(β)
Expand Down Expand Up @@ -83,14 +83,10 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
if i == 1
@. U_exp = u
elseif !iszero(β[i - 1])
if !isnothing(T_lim!)
@. U_lim = U_exp + dt * T_lim
lim!(U_lim, p, t_exp, U_exp)
@. U_exp = U_lim
end
if !isnothing(T_exp!)
@. U_exp += dt * T_exp
end
@. U_lim = U_exp + dt * T_lim
lim!(U_lim, p, t_exp, U_exp)
@. U_exp = U_lim
@. U_exp += dt * T_exp
@. U_exp = (1 - β[i - 1]) * u + β[i - 1] * U_exp
end

Expand Down Expand Up @@ -153,26 +149,18 @@ function step_u!(integrator, cache::IMEXSSPRKCache)
end

if !iszero(β[i])
if !isnothing(T_lim!)
T_lim!(T_lim, U, p, t_exp)
end
if !isnothing(T_exp!)
T_exp!(T_exp, U, p, t_exp)
end
T_lim!(T_lim, U, p, t_exp)
T_exp!(T_exp, U, p, t_exp)
end
end

t_final = t + dt

if !iszero(β[s])
if !isnothing(T_lim!)
@. U_lim = U_exp + dt * T_lim
lim!(U_lim, p, t_final, U_exp)
@. U_exp = U_lim
end
if !isnothing(T_exp!)
@. U_exp += dt * T_exp
end
@. U_lim = U_exp + dt * T_lim
lim!(U_lim, p, t_final, U_exp)
@. U_exp = U_lim
@. U_exp += dt * T_exp
@. u = (1 - β[s]) * u + β[s] * U_exp
end

Expand Down