From 04bf6a49d6f9425c32a5a304ec61b44a3f6d5f2b Mon Sep 17 00:00:00 2001 From: Charles Kawczynski Date: Tue, 10 Oct 2023 15:47:07 -0700 Subject: [PATCH] Always compute+add T_exp and T_lim --- src/functions.jl | 34 +++++++++++++++++++------ src/solvers/imex_ark.jl | 52 +++++++++++++++++---------------------- src/solvers/imex_ssprk.jl | 46 +++++++++++++--------------------- 3 files changed, 65 insertions(+), 67 deletions(-) diff --git a/src/functions.jl b/src/functions.jl index 4e80446c..f9320f0b 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -4,14 +4,32 @@ export ClimaODEFunction, ForwardEulerODEFunction abstract type AbstractClimaODEFunction <: DiffEqBase.AbstractODEFunction{true} end -Base.@kwdef struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction - T_lim!::TL = nothing # nothing or (uₜ, u, p, t) -> ... - T_exp!::TE = nothing # nothing or (uₜ, u, p, t) -> ... - T_imp!::TI = nothing # nothing or (uₜ, u, p, t) -> ... - lim!::L = (u, p, t, u_ref) -> nothing - dss!::D = (u, p, t) -> nothing - post_explicit!::PE = (u, p, t) -> nothing - post_implicit!::PI = (u, p, t) -> nothing +struct ClimaODEFunction{TL, TE, TI, L, D, PE, PI} <: AbstractClimaODEFunction + T_lim!::TL + T_exp!::TE + T_imp!::TI + lim!::L + dss!::D + post_explicit!::PE + post_implicit!::PI +end +function ClimaODEFunction(; + T_lim! = nothing, + T_exp! = nothing, + T_imp! = nothing, + lim! = nothing, + dss! = nothing, + post_explicit! = nothing, + post_implicit! = nothing, +) + isnothing(T_lim!) && (T_lim! = (uₜ, u, p, t) -> nothing) + isnothing(T_exp!) && (T_exp! = (uₜ, u, p, t) -> nothing) + T_imp! = nothing + isnothing(lim!) && (lim! = (u, p, t, u_ref) -> nothing) + isnothing(dss!) && (dss! = (u, p, t) -> nothing) + isnothing(post_explicit!) && (post_explicit! = (u, p, t) -> nothing) + isnothing(post_implicit!) && (post_implicit! = (u, p, t) -> nothing) + return ClimaODEFunction(T_lim!, T_exp!, T_imp!, lim!, dss!, post_explicit!, post_implicit!) end # Don't wrap a AbstractClimaODEFunction in an ODEFunction (makes ODEProblem work). diff --git a/src/solvers/imex_ark.jl b/src/solvers/imex_ark.jl index 32253ecf..b6f21304 100644 --- a/src/solvers/imex_ark.jl +++ b/src/solvers/imex_ark.jl @@ -74,19 +74,17 @@ function step_u!(integrator, cache::IMEXARKCache) @. U = u - if !isnothing(T_lim!) # Update based on limited tendencies from previous stages - for j in 1:(i - 1) - iszero(a_exp[i, j]) && continue - @. U += dt * a_exp[i, j] * T_lim[j] - end - lim!(U, p, t_exp, u) + # Update based on limited tendencies from previous stages + for j in 1:(i - 1) + iszero(a_exp[i, j]) && continue + @. U += dt * a_exp[i, j] * T_lim[j] end + lim!(U, p, t_exp, u) - if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages - for j in 1:(i - 1) - iszero(a_exp[i, j]) && continue - @. U += dt * a_exp[i, j] * T_exp[j] - end + # Update based on explicit tendencies from previous stages + for j in 1:(i - 1) + iszero(a_exp[i, j]) && continue + @. U += dt * a_exp[i, j] * T_exp[j] end if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages @@ -147,32 +145,26 @@ function step_u!(integrator, cache::IMEXARKCache) end if !all(iszero, a_exp[:, i]) || !iszero(b_exp[i]) - if !isnothing(T_lim!) - T_lim!(T_lim[i], U, p, t_exp) - end - if !isnothing(T_exp!) - T_exp!(T_exp[i], U, p, t_exp) - end + T_lim!(T_lim[i], U, p, t_exp) + T_exp!(T_exp[i], U, p, t_exp) end end t_final = t + dt - if !isnothing(T_lim!) # Update based on limited tendencies from previous stages - @. temp = u - for j in 1:s - iszero(b_exp[j]) && continue - @. temp += dt * b_exp[j] * T_lim[j] - end - lim!(temp, p, t_final, u) - @. u = temp + # Update based on limited tendencies from previous stages + @. temp = u + for j in 1:s + iszero(b_exp[j]) && continue + @. temp += dt * b_exp[j] * T_lim[j] end + lim!(temp, p, t_final, u) + @. u = temp - if !isnothing(T_exp!) # Update based on explicit tendencies from previous stages - for j in 1:s - iszero(b_exp[j]) && continue - @. u += dt * b_exp[j] * T_exp[j] - end + # Update based on explicit tendencies from previous stages + for j in 1:s + iszero(b_exp[j]) && continue + @. u += dt * b_exp[j] * T_exp[j] end if !isnothing(T_imp!) # Update based on implicit tendencies from previous stages diff --git a/src/solvers/imex_ssprk.jl b/src/solvers/imex_ssprk.jl index 646889ba..0aced115 100644 --- a/src/solvers/imex_ssprk.jl +++ b/src/solvers/imex_ssprk.jl @@ -19,13 +19,13 @@ function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{SSP} s = length(b_exp) inds = ntuple(i -> i, s) inds_T_imp = filter(i -> !all(iszero, a_imp[:, i]) || !iszero(b_imp[i]), inds) - U = similar(u0) - U_exp = similar(u0) - T_lim = similar(u0) - T_exp = similar(u0) - U_lim = similar(u0) - T_imp = SparseContainer(map(i -> similar(u0), collect(1:length(inds_T_imp))), inds_T_imp) - temp = similar(u0) + U = zero(u0) + U_exp = zero(u0) + T_lim = zero(u0) + T_exp = zero(u0) + U_lim = zero(u0) + T_imp = SparseContainer(map(i -> zero(u0), collect(1:length(inds_T_imp))), inds_T_imp) + temp = zero(u0) â_exp = vcat(a_exp, b_exp') β = diag(â_exp, -1) for i in 1:length(β) @@ -83,14 +83,10 @@ function step_u!(integrator, cache::IMEXSSPRKCache) if i == 1 @. U_exp = u elseif !iszero(β[i - 1]) - if !isnothing(T_lim!) - @. U_lim = U_exp + dt * T_lim - lim!(U_lim, p, t_exp, U_exp) - @. U_exp = U_lim - end - if !isnothing(T_exp!) - @. U_exp += dt * T_exp - end + @. U_lim = U_exp + dt * T_lim + lim!(U_lim, p, t_exp, U_exp) + @. U_exp = U_lim + @. U_exp += dt * T_exp @. U_exp = (1 - β[i - 1]) * u + β[i - 1] * U_exp end @@ -153,26 +149,18 @@ function step_u!(integrator, cache::IMEXSSPRKCache) end if !iszero(β[i]) - if !isnothing(T_lim!) - T_lim!(T_lim, U, p, t_exp) - end - if !isnothing(T_exp!) - T_exp!(T_exp, U, p, t_exp) - end + T_lim!(T_lim, U, p, t_exp) + T_exp!(T_exp, U, p, t_exp) end end t_final = t + dt if !iszero(β[s]) - if !isnothing(T_lim!) - @. U_lim = U_exp + dt * T_lim - lim!(U_lim, p, t_final, U_exp) - @. U_exp = U_lim - end - if !isnothing(T_exp!) - @. U_exp += dt * T_exp - end + @. U_lim = U_exp + dt * T_lim + lim!(U_lim, p, t_final, U_exp) + @. U_exp = U_lim + @. U_exp += dt * T_exp @. u = (1 - β[s]) * u + β[s] * U_exp end