|
"""
    has_jac(T_imp!)

Return `true` when `T_imp!` has both a `Wfact` field and a `jac_prototype`
field and neither of them is `nothing`; return `false` otherwise.
"""
function has_jac(T_imp!)
    F = typeof(T_imp!)
    # The fields must exist before they can be inspected.
    (hasfield(F, :Wfact) && hasfield(F, :jac_prototype)) || return false
    return !(isnothing(T_imp!.Wfact) || isnothing(T_imp!.jac_prototype))
end
| 6 | + |
"""
    imp_error(name)

Throw an error stating that a `NewtonsMethod` must be specified alongside
`T_imp!`. The message starts with `name`, or with a generic description of the
tableau when `name` is `nothing`.
"""
function imp_error(name)
    subject = isnothing(name) ? "The given IMEXTableau" : name
    msg = string(
        subject,
        " has implicit stages that require a nonlinear solver, ",
        "so NewtonsMethod must be specified alongside T_imp!.",
    )
    error(msg)
end
| 10 | + |
"""
    sdirk_error(name)

Throw an error stating that the tableau is not SDIRK, so updating on the
`NewTimeStep` signal is not allowed. The message refers to the tableau by
`name`, or by a generic description when `name` is `nothing`.
"""
function sdirk_error(name)
    unnamed = isnothing(name)
    subject = unnamed ? "The given IMEXTableau" : name
    object = unnamed ? "this tableau" : name
    msg = string(
        subject,
        " has implicit stages with distinct coefficients (it ",
        "is not SDIRK), and an update is required whenever a ",
        "stage has a different coefficient from the previous ",
        "stage. Do not update on the NewTimeStep signal when ",
        "using ",
        object,
        ".",
    )
    error(msg)
end
| 17 | + |
"""
    IMEXARKCache{TimestepperCache, NewtonsCache}

Container pairing the preallocated timestepper data (`timestepper_cache`) with
the cache of the nonlinear solver (`newtons_method_cache`). Both fields are
stored with their concrete types as type parameters.
"""
struct IMEXARKCache{TimestepperCache, NewtonsCache}
    timestepper_cache::TimestepperCache
    newtons_method_cache::NewtonsCache
end
| 22 | + |
"""
    init_cache(prob, alg::IMEXAlgorithm{Unconstrained}; kwargs...)

Build the `IMEXARKCache` that `step_u!` uses for the tableau stored in
`alg.tableau`.

The coefficient matrices `a_exp` and `a_imp` are extended with `b_exp'` and
`b_imp'` as an extra row, so that the final state of a step is treated as stage
`s + 1`. Stages are split by the diagonal of the extended implicit matrix into
stages with a zero coefficient (`z_stages`) and stages with a nonzero
coefficient (`nz_stages`). Temporary arrays are allocated with `similar(u0)`
for every tendency that appears in a later stage or in the final state.

Calls `imp_error` when `is_accessible(ΔU_imp_values_sparse)` holds but
`alg.newtons_method` is `nothing`. The keyword arguments are accepted but not
used here.
"""
function init_cache(prob::DiffEqBase.AbstractODEProblem, alg::IMEXAlgorithm{Unconstrained}; kwargs...)
    (; u0) = prob
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = prob.f
    (; name, newtons_method) = alg
    (; a_exp, b_exp, a_imp, b_imp) = alg.tableau

    no_T_lim = isnothing(T_lim!) && isnothing(T_exp_T_lim!)
    no_T_exp = isnothing(T_exp!) && isnothing(T_exp_T_lim!)
    no_T_imp = isnothing(T_imp!)

    s = size(a_imp, 1) # number of internal stages

    # Extend the coefficient matrices a_exp and a_imp by interpreting the final
    # state on each step as stage s + 1.
    A_exp = vcat(a_exp, b_exp')
    A_imp = vcat(a_imp, b_imp')

    z_stages = findall(iszero, diag(A_imp)) # stages without implicit solves
    nz_stages = findall(!iszero, diag(A_imp)) # stages with implicit solves
    stages_needing_T_exp = findall(col -> any(!iszero, col), eachcol(A_exp))
    # Filter z_stages directly so that the result holds stage indices. Applying
    # findall to the columns of A_imp[:, z_stages] would instead return
    # positions within z_stages, which only coincide with stage indices when
    # z_stages is of the form 1:k.
    z_stages_needing_T_imp = filter(stage -> any(!iszero, view(A_imp, :, stage)), z_stages)
    # All nz stages are computed using ΔU_imp, rather than T_imp.

    Γ = A_imp[nz_stages, nz_stages] # "fully implicit" part of A_imp
    # A single shared diagonal coefficient is recorded; otherwise nothing.
    sdirk_γ = length(unique(diag(Γ))) == 1 ? diag(Γ)[1] : nothing

    temp_value1 = similar(u0)
    temp_value2 = similar(u0)
    T_lim_values_sparse = no_T_lim ? SparseTuple() : SparseTuple(_ -> similar(u0), stages_needing_T_exp)
    T_exp_values_sparse = no_T_exp ? SparseTuple() : SparseTuple(_ -> similar(u0), stages_needing_T_exp)
    T_imp_values_sparse = no_T_imp ? SparseTuple() : SparseTuple(_ -> similar(u0), z_stages_needing_T_imp)
    ΔU_imp_values_sparse = no_T_imp ? SparseTuple() : SparseTuple(_ -> similar(u0), nz_stages)

    ΔtT_lim_to_Δu_lim_tuples_sparse = no_T_lim ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    ΔtT_exp_to_Δu_exp_tuples_sparse = no_T_exp ? SparseTuple() : sparse_matrix_rows(A_exp, 1:(s + 1), 1:s)
    ΔtT_imp_to_Δu_imp_tuples_sparse =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(A_imp[:, z_stages], 1:(s + 1), z_stages)

    # Columns of A_imp for the nz stages, with the diagonal of Γ removed, so
    # that only contributions from other nz stages remain in the Γ block.
    A_imp_lower_nz_component = A_imp[:, nz_stages]
    A_imp_lower_nz_component[nz_stages, nz_stages] .= Γ - Diagonal(Γ)
    prev_ΔU_imp_to_Δu_imp_tuples_sparse =
        no_T_imp ? SparseTuple() : sparse_matrix_rows(A_imp_lower_nz_component * inv(Γ), 1:(s + 1), nz_stages)

    # Convert all values that will be passed to unrolled_foreach in step_u! into
    # tuples of length s.
    T_lim_values = dense_tuple(T_lim_values_sparse, s, nothing)
    T_exp_values = dense_tuple(T_exp_values_sparse, s, nothing)
    T_imp_values = dense_tuple(T_imp_values_sparse, s, nothing)
    ΔU_imp_values = dense_tuple(ΔU_imp_values_sparse, s, nothing)
    ΔtT_lim_to_Δu_lim_tuples = dense_tuple(ΔtT_lim_to_Δu_lim_tuples_sparse, s, SparseTuple())
    ΔtT_exp_to_Δu_exp_tuples = dense_tuple(ΔtT_exp_to_Δu_exp_tuples_sparse, s, SparseTuple())
    ΔtT_imp_to_Δu_imp_tuples = dense_tuple(ΔtT_imp_to_Δu_imp_tuples_sparse, s, SparseTuple())
    prev_ΔU_imp_to_Δu_imp_tuples = dense_tuple(prev_ΔU_imp_to_Δu_imp_tuples_sparse, s, SparseTuple())

    timestepper_cache = (;
        sdirk_γ,
        temp_value1,
        temp_value2,
        T_lim_values_sparse,
        T_exp_values_sparse,
        T_imp_values_sparse,
        ΔU_imp_values_sparse,
        T_lim_values,
        T_exp_values,
        T_imp_values,
        ΔU_imp_values,
        ΔtT_lim_to_Δu_lim_tuples,
        ΔtT_exp_to_Δu_exp_tuples,
        ΔtT_imp_to_Δu_imp_tuples,
        prev_ΔU_imp_to_Δu_imp_tuples,
    )

    # A Newton cache is only allocated when there is ΔU_imp storage, and in
    # that case a NewtonsMethod must have been provided.
    newtons_method_cache = if is_accessible(ΔU_imp_values_sparse)
        isnothing(newtons_method) && imp_error(name)
        j = has_jac(T_imp!) ? T_imp!.jac_prototype : nothing
        allocate_cache(newtons_method, u0, j)
    else
        nothing
    end

    return IMEXARKCache(timestepper_cache, newtons_method_cache)
end
| 105 | + |
"""
    step_u!(integrator, cache::IMEXARKCache)

Advance `integrator.u` in place by one step of size `integrator.dt`, using the
tableau in `integrator.alg` and the storage in `cache`, and return `u`.

For each of the `s` internal stages, the stage value is assembled from the
limited, explicit, and implicit contributions of earlier stages. On stages
that have `ΔU_imp` storage, `solve_newton!` is then called to solve the
implicit equation; on the other stages the assembled value is used directly.
The tendencies needed by later stages are evaluated at the stage value. After
the last stage, the same assembly is done with row `s + 1` of the coefficient
tuples to overwrite `u`, followed by `dss!` and `post_implicit!`.

Calls `sdirk_error` if the Jacobian is to be updated on the `NewTimeStep`
signal while the cache holds no single SDIRK coefficient.
"""
function step_u!(integrator, cache::IMEXARKCache)
    (; u, p, t, alg) = integrator
    (; T_lim!, T_exp!, T_exp_T_lim!, T_imp!) = integrator.sol.prob.f
    (; lim!, dss!, post_explicit!, post_implicit!) = integrator.sol.prob.f
    (; name, newtons_method) = alg
    (; a_imp, c_exp, c_imp) = alg.tableau
    (; newtons_method_cache) = cache
    (;
        sdirk_γ,
        temp_value1,
        temp_value2,
        T_lim_values_sparse,
        T_exp_values_sparse,
        T_imp_values_sparse,
        ΔU_imp_values_sparse,
        T_lim_values,
        T_exp_values,
        T_imp_values,
        ΔU_imp_values,
        ΔtT_lim_to_Δu_lim_tuples,
        ΔtT_exp_to_Δu_exp_tuples,
        ΔtT_imp_to_Δu_imp_tuples,
        prev_ΔU_imp_to_Δu_imp_tuples,
    ) = cache.timestepper_cache

    Δt = integrator.dt
    s = size(a_imp, 1)

    # Update the Jacobian once per step if requested. This uses a single
    # coefficient for the whole step, so it requires sdirk_γ to be available.
    if !isnothing(newtons_method_cache)
        (; update_j) = newtons_method
        (; j) = newtons_method_cache
        if !isnothing(j) && needs_update!(update_j, NewTimeStep(t))
            isnothing(sdirk_γ) && sdirk_error(name)
            T_imp!.Wfact(j, u, p, Δt * sdirk_γ, t)
        end
    end

    unrolled_foreach(
        ntuple(identity, s),
        T_lim_values,
        T_exp_values,
        T_imp_values,
        ΔU_imp_values,
        ΔtT_lim_to_Δu_lim_tuples,
        ΔtT_exp_to_Δu_exp_tuples,
        ΔtT_imp_to_Δu_imp_tuples,
        prev_ΔU_imp_to_Δu_imp_tuples,
    ) do (
        stage,
        T_lim,
        T_exp,
        T_imp,
        ΔU_imp,
        ΔtT_lim_to_ΔU_lim,
        ΔtT_exp_to_ΔU_exp,
        ΔtT_imp_to_ΔU_imp,
        prev_ΔU_imp_to_ΔU_imp,
    )
        t_exp = t + Δt * c_exp[stage]
        t_imp = t + Δt * c_imp[stage]
        Δtγ = Δt * a_imp[stage, stage]

        ΔU_lim_over_Δt = sparse_broadcasted_dot(ΔtT_lim_to_ΔU_lim, T_lim_values_sparse)
        ΔU_exp_over_Δt = sparse_broadcasted_dot(ΔtT_exp_to_ΔU_exp, T_exp_values_sparse)
        ΔU_imp_from_T_imp_over_Δt = sparse_broadcasted_dot(ΔtT_imp_to_ΔU_imp, T_imp_values_sparse)
        ΔU_imp_from_prev_ΔU_imp = sparse_broadcasted_dot(prev_ΔU_imp_to_ΔU_imp, ΔU_imp_values_sparse)

        # Apply the limited increment first, using temp_value1 as storage.
        if is_accessible(ΔtT_lim_to_ΔU_lim)
            u_plus_ΔU_lim = temp_value1
            @. u_plus_ΔU_lim = u + Δt * ΔU_lim_over_Δt
            lim!(u_plus_ΔU_lim, p, t_exp, u)
        else
            u_plus_ΔU_lim = u
        end

        # Add the remaining known increments, using temp_value2 as storage.
        if is_accessible(ΔtT_exp_to_ΔU_exp) || is_accessible(ΔtT_imp_to_ΔU_imp) || is_accessible(prev_ΔU_imp_to_ΔU_imp)
            U_before_solve = temp_value2
            @. U_before_solve =
                u_plus_ΔU_lim + Δt * (ΔU_exp_over_Δt + ΔU_imp_from_T_imp_over_Δt) + ΔU_imp_from_prev_ΔU_imp
        else
            U_before_solve = u_plus_ΔU_lim
        end

        # True whenever U_before_solve was written above, i.e., it is not u.
        is_not_u_before_solve =
            is_accessible(ΔtT_lim_to_ΔU_lim) ||
            is_accessible(ΔtT_exp_to_ΔU_exp) ||
            is_accessible(ΔtT_imp_to_ΔU_imp) ||
            is_accessible(prev_ΔU_imp_to_ΔU_imp)

        # TODO: Rename post_explicit! to pre_newton_iteration!, and rename
        # post_implicit! to post_stage!. Make pre_newton_iteration! only set
        # precomputed quantities needed by T_imp! and T_imp!.Wfact. Keep
        # post_stage! as it is now, so that it sets all precomputed quantities.

        if !isnothing(ΔU_imp)
            is_not_u_before_solve && post_explicit!(U_before_solve, p, t_imp)

            U = ΔU_imp # Use ΔU_imp as additional temporary storage.
            @. U = U_before_solve

            # Solve U ≈ U_before_solve + Δtγ * T_imp(U, p, t_imp) for U.
            solve_newton!(
                newtons_method,
                newtons_method_cache,
                U,
                (residual, U) -> begin
                    T_imp!(residual, U, p, t_imp)
                    @. residual = U_before_solve - U + Δtγ * residual
                end,
                (j, U) -> T_imp!.Wfact(j, U, p, Δtγ, t_imp), # j = ∂residual/∂U
                U -> post_explicit!(U, p, t_imp),
                U -> post_implicit!(U, p, t_imp),
            )
        else
            U = U_before_solve # There is no solve on this stage.
            is_not_u_before_solve && post_implicit!(U, p, t_imp)
        end

        if !isnothing(T_lim) || !isnothing(T_exp)
            if !isnothing(T_exp_T_lim!)
                T_exp_T_lim!(T_exp, T_lim, U, p, t_exp)
                if stage != s
                    # TODO: Fuse these two DSS calls into one.
                    dss!(T_lim, p, t_exp)
                    dss!(T_exp, p, t_exp)
                end
            end
            # TODO: Drop support for specifying T_lim! separately from T_exp!.
            if !isnothing(T_lim!)
                T_lim!(T_lim, U, p, t_exp)
                stage != s && dss!(T_lim, p, t_exp)
            end
            if !isnothing(T_exp!)
                T_exp!(T_exp, U, p, t_exp)
                stage != s && dss!(T_exp, p, t_exp)
            end
        end
        if !isnothing(T_imp)
            T_imp!(T_imp, U, p, t_imp)
        end
        if !isnothing(ΔU_imp)
            @. ΔU_imp = U - u_plus_ΔU_lim - Δt * ΔU_exp_over_Δt
            # = U - U_before_solve + Δt * ΔU_imp_from_T_imp_over_Δt +
            #   ΔU_imp_from_prev_ΔU_imp
        end
    end

    t_final = t + Δt
    ΔtT_lim_to_final_Δu_lim = ΔtT_lim_to_Δu_lim_tuples[s + 1]
    ΔtT_exp_to_final_Δu_exp = ΔtT_exp_to_Δu_exp_tuples[s + 1]
    ΔtT_imp_to_final_Δu_imp = ΔtT_imp_to_Δu_imp_tuples[s + 1]
    ΔU_imp_to_final_Δu_imp = prev_ΔU_imp_to_Δu_imp_tuples[s + 1]
    final_Δu_lim_over_Δt = sparse_broadcasted_dot(ΔtT_lim_to_final_Δu_lim, T_lim_values_sparse)
    final_Δu_exp_over_Δt = sparse_broadcasted_dot(ΔtT_exp_to_final_Δu_exp, T_exp_values_sparse)
    final_Δu_imp_from_T_imp_over_Δt = sparse_broadcasted_dot(ΔtT_imp_to_final_Δu_imp, T_imp_values_sparse)
    final_Δu_imp_from_ΔU_imp = sparse_broadcasted_dot(ΔU_imp_to_final_Δu_imp, ΔU_imp_values_sparse)

    if is_accessible(ΔtT_lim_to_final_Δu_lim)
        # The stage loop has finished, so temp_value1 is free to be reused as
        # the buffer for the limited final state.
        final_u_plus_Δu_lim = temp_value1
        @. final_u_plus_Δu_lim = u + Δt * final_Δu_lim_over_Δt
        lim!(final_u_plus_Δu_lim, p, t_final, u)
    else
        final_u_plus_Δu_lim = u
    end

    @. u =
        final_u_plus_Δu_lim + Δt * (final_Δu_exp_over_Δt + final_Δu_imp_from_T_imp_over_Δt) + final_Δu_imp_from_ΔU_imp
    dss!(u, p, t_final)
    post_implicit!(u, p, t_final)

    return u
end
0 commit comments