|
"""
    has_jac(T_imp!)

Return `true` when `T_imp!` has both a `Wfact` field and a `jac_prototype`
field and neither of them is `nothing`; return `false` otherwise.
"""
function has_jac(T_imp!)
    F = typeof(T_imp!)
    # Both fields must exist before their values can be inspected.
    (hasfield(F, :Wfact) && hasfield(F, :jac_prototype)) || return false
    return !(isnothing(T_imp!.Wfact) || isnothing(T_imp!.jac_prototype))
end
| 6 | + |
"""
    imp_error(name)

Throw an error stating that the tableau called `name` (or "The given
IMEXTableau" when `name` is `nothing`) has implicit stages that need a
nonlinear solver, so `NewtonsMethod` must be specified alongside `T_imp!`.
"""
function imp_error(name)
    subject = isnothing(name) ? "The given IMEXTableau" : name
    return error(
        "$(subject) has implicit stages that require a nonlinear solver, " *
        "so NewtonsMethod must be specified alongside T_imp!.",
    )
end
| 10 | + |
"""
    sdirk_error(name)

Throw an error stating that the tableau called `name` is not SDIRK (its
implicit stages have distinct coefficients), so updating on the
`NewTimeStep` signal is not allowed. When `name` is `nothing`, the message
refers to "The given IMEXTableau" / "this tableau" instead.
"""
function sdirk_error(name)
    unnamed = isnothing(name)
    subject = unnamed ? "The given IMEXTableau" : name
    object = unnamed ? "this tableau" : name
    return error(
        "$(subject) has implicit stages with distinct coefficients (it is not " *
        "SDIRK), and an update is required whenever a stage has a different " *
        "coefficient from the previous stage. Do not update on the NewTimeStep " *
        "signal when using $(object).",
    )
end
| 17 | + |
"""
    IMEXARKCache{SC, NC}

Immutable container pairing the per-stage data (`stage_cache`) with the
cache of the nonlinear solver (`newtons_method_cache`). Both fields are
fully parametric, so any value may be stored in either one.
"""
struct IMEXARKCache{SC, NC}
    stage_cache::SC
    newtons_method_cache::NC
end
| 22 | + |
"""
    init_cache(prob, alg::IMEXAlgorithm{Unconstrained}; kwargs...)

Build the `IMEXARKCache` consumed by `step_u!`.

The cache holds

- one `similar(u0)` buffer per stage whose explicit / limited / implicit
  tendency (or implicit increment `ΔU_imp`) is needed by a later stage or by
  the final update,
- for every stage `1:s`, plus the final state treated as stage `s + 1`, the
  tableau coefficients that multiply those buffers, and
- the cache of `alg.newtons_method` (or `nothing` when no Newton's method was
  given).

Calls `imp_error(name)` when the tableau has stages with a nonzero diagonal
implicit coefficient, `T_imp!` is present, and no Newton's method was
specified. `kwargs` are accepted but not used here.
"""
function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unconstrained}; kwargs...)
    (; u0) = prob
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = prob.f
    (; name, newtons_method) = alg
    (; a_exp, b_exp, a_imp, b_imp) = alg.tableau

    no_T_lim = isnothing(T_lim!) && isnothing(T_exp_T_lim!)
    no_T_exp = isnothing(T_exp!) && isnothing(T_exp_T_lim!)
    no_T_imp = isnothing(T_imp!)

    # `length` has no dimension argument; the number of stages is the number
    # of rows of the implicit coefficient matrix.
    s = size(a_imp, 1)

    A_exp = vcat(a_exp, b_exp') # exp coefs with final state seen as stage s + 1
    A_imp = vcat(a_imp, b_imp') # imp coefs with final state seen as stage s + 1

    imp_diagonal = diag(a_imp)
    z_stages = findall(iszero, imp_diagonal) # stages without implicit solves
    nz_stages = findall(!iszero, imp_diagonal) # stages with implicit solves
    exp_stages = findall(col -> any(!iszero, col), eachcol(A_exp))
    # `findall` returns column positions within `A_imp[:, z_stages]`; map them
    # back through `z_stages` so that they are stage indices, like the column
    # labels passed to `sparse_matrix_rows` below.
    # NOTE(review): assumes the second argument of `SparseTuple` holds stage
    # indices, as it does for `exp_stages` and `nz_stages` — confirm.
    imp_z_stages = z_stages[findall(col -> any(!iszero, col), eachcol(A_imp[:, z_stages]))]

    # Γ must be built after `nz_stages` is known.
    Γ = a_imp[nz_stages, nz_stages] # "fully implicit" part of a_imp
    Γ_diagonal = diag(Γ)
    # A single shared diagonal coefficient means the implicit part is SDIRK.
    sdirk_γ = length(unique(Γ_diagonal)) == 1 ? first(Γ_diagonal) : nothing

    temp_value = similar(u0)
    T_lim_by_stage = no_T_lim ? SparseTuple() : SparseTuple(map(_ -> similar(u0), exp_stages), exp_stages)
    T_exp_by_stage = no_T_exp ? SparseTuple() : SparseTuple(map(_ -> similar(u0), exp_stages), exp_stages)
    T_imp_by_stage = no_T_imp ? SparseTuple() : SparseTuple(map(_ -> similar(u0), imp_z_stages), imp_z_stages)
    ΔU_imp_by_stage = no_T_imp ? SparseTuple() : SparseTuple(map(_ -> similar(u0), nz_stages), nz_stages)

    # Implicit solves are required but no nonlinear solver was provided.
    isnothing(newtons_method) && !isempty(ΔU_imp_by_stage) && imp_error(name)

    prev_ΔtT_lim_to_Δu_coefs_by_stage = no_T_lim ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    prev_ΔtT_exp_to_Δu_coefs_by_stage = no_T_exp ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    prev_z_ΔtT_imp_to_Δu_coefs_by_stage =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(A_imp[:, z_stages], 1:(s + 1), z_stages)

    # Coefficients applied to the stored implicit increments: the columns of
    # A_imp for the implicitly solved stages, with the diagonal of Γ zeroed
    # out, multiplied by inv(Γ).
    prev_nz_ΔtT_imp_to_Δu_coef_matrix = A_imp[:, nz_stages]
    prev_nz_ΔtT_imp_to_Δu_coef_matrix[nz_stages, nz_stages] .= Γ - Diagonal(Γ)
    prev_nz_ΔU_imp_to_Δu_coefs_by_stage =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(prev_nz_ΔtT_imp_to_Δu_coef_matrix * inv(Γ), 1:(s + 1), nz_stages)

    # Convert all inputs to unrolled_foreach in step_u! into regular tuples.
    T_lim_by_stage_dense = dense_tuple(T_lim_by_stage, s, nothing)
    T_exp_by_stage_dense = dense_tuple(T_exp_by_stage, s, nothing)
    T_imp_by_stage_dense = dense_tuple(T_imp_by_stage, s, nothing)
    ΔU_imp_by_stage_dense = dense_tuple(ΔU_imp_by_stage, s, nothing)
    prev_ΔtT_lim_to_Δu_coefs_by_stage_dense = dense_tuple(prev_ΔtT_lim_to_Δu_coefs_by_stage, s, SparseTuple())
    prev_ΔtT_exp_to_Δu_coefs_by_stage_dense = dense_tuple(prev_ΔtT_exp_to_Δu_coefs_by_stage, s, SparseTuple())
    prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense = dense_tuple(prev_z_ΔtT_imp_to_Δu_coefs_by_stage, s, SparseTuple())
    prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense = dense_tuple(prev_nz_ΔU_imp_to_Δu_coefs_by_stage, s, SparseTuple())

    stage_cache = (;
        s,
        sdirk_γ,
        temp_value,
        T_lim_by_stage,
        T_exp_by_stage,
        T_imp_by_stage,
        ΔU_imp_by_stage,
        T_lim_by_stage_dense,
        T_exp_by_stage_dense,
        T_imp_by_stage_dense,
        ΔU_imp_by_stage_dense,
        prev_ΔtT_lim_to_Δu_coefs_by_stage_dense,
        prev_ΔtT_exp_to_Δu_coefs_by_stage_dense,
        prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense,
        prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense,
    )

    # Only pass a Jacobian prototype when T_imp! actually provides one.
    newtons_method_cache =
        isnothing(newtons_method) ? nothing :
        allocate_cache(newtons_method, u0, has_jac(T_imp!) ? T_imp!.jac_prototype : nothing)

    return IMEXARKCache(stage_cache, newtons_method_cache)
end
| 97 | + |
"""
    step_u!(integrator, cache::IMEXARKCache)

Advance `integrator.u` in place by one step of size `integrator.dt` using the
tableau of `integrator.alg`, and return `u`.

For each stage `1:s` this

1. forms the stage value from `u` and the tendencies / implicit increments
   stored by earlier stages (applying `lim!` to the limited part),
2. calls `post_explicit!` on that value,
3. on stages that have a `ΔU_imp` buffer, solves the implicit equation with
   `solve_newton!`,
4. evaluates and stores the tendencies needed by later stages, applying
   `dss!` to the explicit and limited ones.

After the stages, the same combination is evaluated with the coefficients of
"stage" `s + 1` to produce the new `u`, followed by `dss!` and
`post_explicit!`.

Calls `sdirk_error(name)` if a Jacobian update is requested on the
`NewTimeStep` signal but the tableau has no single SDIRK coefficient.
"""
function step_u!(integrator, cache::IMEXARKCache)
    (; u, p, t, alg) = integrator
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = integrator.sol.prob.f
    (; lim!, dss!, post_explicit!, post_implicit!) = integrator.sol.prob.f
    (; name, newtons_method) = alg
    (; a_imp, c_exp, c_imp) = alg.tableau
    (; newtons_method_cache) = cache
    (;
        s,
        sdirk_γ,
        temp_value,
        T_lim_by_stage,
        T_exp_by_stage,
        T_imp_by_stage,
        ΔU_imp_by_stage,
        T_lim_by_stage_dense,
        T_exp_by_stage_dense,
        T_imp_by_stage_dense,
        ΔU_imp_by_stage_dense,
        prev_ΔtT_lim_to_Δu_coefs_by_stage_dense,
        prev_ΔtT_exp_to_Δu_coefs_by_stage_dense,
        prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense,
        prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense,
    ) = cache.stage_cache

    Δt = integrator.dt

    # Optionally refresh the Jacobian once per time step. This is only
    # possible when every implicit stage shares the coefficient sdirk_γ.
    if !isnothing(T_imp!) && !isnothing(newtons_method)
        (; update_j) = newtons_method
        (; j) = newtons_method_cache
        if !isnothing(j) && needs_update!(update_j, NewTimeStep(t))
            isnothing(sdirk_γ) && sdirk_error(name)
            T_imp!.Wfact(j, u, p, Δt * sdirk_γ, t)
        end
    end

    unrolled_foreach(
        1:s,
        T_lim_by_stage_dense[1:s],
        T_exp_by_stage_dense[1:s],
        T_imp_by_stage_dense[1:s],
        ΔU_imp_by_stage_dense[1:s],
        prev_ΔtT_lim_to_Δu_coefs_by_stage_dense[1:s],
        prev_ΔtT_exp_to_Δu_coefs_by_stage_dense[1:s],
        prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense[1:s],
        prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense[1:s],
    ) do (
        stage,
        T_lim,
        T_exp,
        T_imp,
        ΔU_imp,
        prev_ΔtT_lim_to_Δu_coefs,
        prev_ΔtT_exp_to_Δu_coefs,
        prev_z_ΔtT_imp_to_Δu_coefs,
        prev_nz_ΔU_imp_to_Δu_coefs,
    )
        t_exp = t + Δt * c_exp[stage]
        t_imp = t + Δt * c_imp[stage]
        Δtγ = Δt * a_imp[stage, stage]

        Δu_over_Δt_from_prev_T_lim = broadcasted_dot(prev_ΔtT_lim_to_Δu_coefs, T_lim_by_stage)
        Δu_over_Δt_from_prev_T_exp = broadcasted_dot(prev_ΔtT_exp_to_Δu_coefs, T_exp_by_stage)
        Δu_over_Δt_from_prev_T_imp = broadcasted_dot(prev_z_ΔtT_imp_to_Δu_coefs, T_imp_by_stage)
        Δu_from_prev_ΔU_imp = broadcasted_dot(prev_nz_ΔU_imp_to_Δu_coefs, ΔU_imp_by_stage)

        # Limited contribution: skipped entirely when no earlier stage feeds
        # into this one, in which case `u` itself is used.
        if isempty(prev_ΔtT_lim_to_Δu_coefs)
            u_plus_Δu_lim = u
        else
            u_plus_Δu_lim = temp_value
            @. u_plus_Δu_lim = u + Δt * Δu_over_Δt_from_prev_T_lim
            lim!(u_plus_Δu_lim, p, t_exp, u)
        end

        # Assigned exactly once so that the residual closure below can capture
        # it without boxing. When written, `temp_value` may alias
        # `u_plus_Δu_lim`; the fused broadcast is elementwise, so this is safe.
        U_before_solve = if (
            isempty(prev_ΔtT_exp_to_Δu_coefs) &&
            isempty(prev_z_ΔtT_imp_to_Δu_coefs) &&
            isempty(prev_nz_ΔU_imp_to_Δu_coefs)
        )
            u_plus_Δu_lim
        else
            @. temp_value =
                u_plus_Δu_lim + Δt * (Δu_over_Δt_from_prev_T_exp + Δu_over_Δt_from_prev_T_imp) + Δu_from_prev_ΔU_imp
            temp_value
        end

        post_explicit!(U_before_solve, p, t_imp)

        if !isnothing(ΔU_imp)
            # Use ΔU_imp as additional temporary storage, since its value does
            # not need to be set until the end of each stage.
            U = ΔU_imp
            @. U = U_before_solve

            # Solve U ≈ U_before_solve + Δtγ * T_imp(U, p, t_imp) for U.
            set_residual! = (residual, U) -> begin
                T_imp!(residual, U, p, t_imp)
                @. residual = U_before_solve - U + Δtγ * residual
            end
            set_jacobian! = (jacobian, U) -> T_imp!.Wfact(jacobian, U, p, Δtγ, t_imp)
            # The wrapper needs its own name: assigning to `post_implicit!`
            # here would rebind the enclosing function's variable, making the
            # wrapper call itself instead of the user callback.
            call_post_implicit! = U -> post_implicit!(U, p, t_imp)
            solve_newton!(
                newtons_method,
                newtons_method_cache,
                U,
                set_residual!,
                set_jacobian!,
                call_post_implicit!,
                call_post_implicit!,
            )
        else
            U = U_before_solve # There is no solve on this stage.
        end

        if !isnothing(T_lim) || !isnothing(T_exp)
            if !isnothing(T_exp_T_lim!)
                T_exp_T_lim!(T_exp, T_lim, U, p, t_exp)
                dss!(T_lim, p, t_exp)
                dss!(T_exp, p, t_exp)
            end
            if !isnothing(T_lim!)
                T_lim!(T_lim, U, p, t_exp)
                dss!(T_lim, p, t_exp)
            end
            if !isnothing(T_exp!)
                T_exp!(T_exp, U, p, t_exp)
                dss!(T_exp, p, t_exp)
            end
            # TODO: Can we just use T_exp_T_lim!, and not 3 different functions?
            # TODO: Fuse the DSS calls above into a single operation.
        end

        if !isnothing(ΔU_imp)
            # `U` aliases `ΔU_imp` on this branch; the update is elementwise.
            # TODO: Subtract the T_lim contribution from this term.
            @. ΔU_imp = U - u - Δt * Δu_over_Δt_from_prev_T_exp
        elseif !isnothing(T_imp)
            T_imp!(T_imp, U, p, t_imp)
        end
    end

    # Final update: same combination as a stage, using the coefficient rows
    # stored for index s + 1.
    t_final = t + Δt
    prev_ΔtT_lim_to_Δu_coefs_final = prev_ΔtT_lim_to_Δu_coefs_by_stage_dense[s + 1]
    prev_ΔtT_exp_to_Δu_coefs_final = prev_ΔtT_exp_to_Δu_coefs_by_stage_dense[s + 1]
    prev_z_ΔtT_imp_to_Δu_coefs_final = prev_z_ΔtT_imp_to_Δu_coefs_by_stage_dense[s + 1]
    prev_nz_ΔU_imp_to_Δu_coefs_final = prev_nz_ΔU_imp_to_Δu_coefs_by_stage_dense[s + 1]
    Δu_over_Δt_from_prev_T_lim_final = broadcasted_dot(prev_ΔtT_lim_to_Δu_coefs_final, T_lim_by_stage)
    Δu_over_Δt_from_prev_T_exp_final = broadcasted_dot(prev_ΔtT_exp_to_Δu_coefs_final, T_exp_by_stage)
    Δu_over_Δt_from_prev_T_imp_final = broadcasted_dot(prev_z_ΔtT_imp_to_Δu_coefs_final, T_imp_by_stage)
    Δu_from_prev_ΔU_imp_final = broadcasted_dot(prev_nz_ΔU_imp_to_Δu_coefs_final, ΔU_imp_by_stage)

    # Named with the `_final` suffix so that it is not the same variable as
    # the one assigned inside the stage closure above (which would force the
    # closure to share, and box, a variable of the enclosing function).
    if isempty(prev_ΔtT_lim_to_Δu_coefs_final)
        u_plus_Δu_lim_final = u
    else
        u_plus_Δu_lim_final = temp_value
        @. u_plus_Δu_lim_final = u + Δt * Δu_over_Δt_from_prev_T_lim_final
        lim!(u_plus_Δu_lim_final, p, t_final, u)
    end

    @. u =
        u_plus_Δu_lim_final +
        Δt * (Δu_over_Δt_from_prev_T_exp_final + Δu_over_Δt_from_prev_T_imp_final) +
        Δu_from_prev_ΔU_imp_final
    dss!(u, p, t_final)
    post_explicit!(u, p, t_final)

    return u
end
0 commit comments