@@ -50,15 +50,16 @@ function init_cache(
5050) where {uType, tType}
5151
5252 tab = tableau (alg, eltype (prob. u0))
53- Nstages = length (tab. B) # TODO : create function for this
53+ (; B, Aimpl) = tab
54+ Nstages = length (B) # TODO : create function for this
5455 U = zero (prob. u0)
5556 L = ntuple (i -> zero (prob. u0), Nstages)
5657 R = ntuple (i -> zero (prob. u0), Nstages)
5758
5859 if prob. f isa DiffEqBase. ODEFunction
59- W = EulerOperator (prob. f. jvp, - dt * tab . Aimpl[2 , 2 ], prob. p, prob. tspan[1 ])
60+ W = EulerOperator (prob. f. jvp, - dt * Aimpl[2 , 2 ], prob. p, prob. tspan[1 ])
6061 elseif prob. f isa DiffEqBase. SplitFunction
61- W = EulerOperator (prob. f. f1, - dt * tab . Aimpl[2 , 2 ], prob. p, prob. tspan[1 ])
62+ W = EulerOperator (prob. f. f1, - dt * Aimpl[2 , 2 ], prob. p, prob. tspan[1 ])
6263 end
6364 linsolve! = alg. linsolve (Val{:init }, W, prob. u0; kwargs... )
6465
@@ -108,95 +109,88 @@ function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}) where {Nstage
108109end
109110
110111function step_u! (int, cache:: AdditiveRungeKuttaFullCache{Nstages} , f:: DiffEqBase.SplitFunction ) where {Nstages}
111- tab = cache. tableau
112- U = cache. U
113- Uhat = cache. R[end ] # can be used as work array, as only used in last stage
114112
113+ (; C, Aimpl, Aexpl, B) = cache. tableau
114+ (; U, R, L, W, linsolve) = cache
115+ (; u, p, t, dt) = int
115116
116- u = int. u
117- p = int. p
118- t = int. t
119- dt = int. dt
117+ U = U
118+ Uhat = R[end ] # can be used as work array, as only used in last stage
120119
121120 fL! = f. f1 # linear part
122121 fR! = f. f2 # remainder
123122
124123 # first stage is always explicit
125- τ = t + tab . C[1 ] * dt
124+ τ = t + C[1 ] * dt
126125
127- # cache. L[i] .= fL(cache. U[i-1], p, t + tab. C[i]*dt)
128- # cache. R[1] .= fR(cache. U[i-1], p, t + tab. C[i]*dt)
126+ # L[i] .= fL(U[i-1], p, t + C[i]*dt)
127+ # R[1] .= fR(U[i-1], p, t + C[i]*dt)
129128
130- fL! (cache . L[1 ], u, p, τ)
131- fR! (cache . R[1 ], u, p, τ)
129+ fL! (L[1 ], u, p, τ)
130+ fR! (R[1 ], u, p, τ)
132131
133132 for i in 2 : Nstages
134133 # solve for W * U = Uhat
135134 # set U to initial guess based on fully explicit
136135 # TODO : we don't need this for direct solves
137136 U .= Uhat .= u
138137 for j in 1 : (i - 1 )
139- Uhat .+ = (dt * tab . Aimpl[i, j]) .* cache . L[j] .+ (dt * tab . Aexpl[i, j]) .* cache . R[j]
138+ Uhat .+ = (dt * Aimpl[i, j]) .* L[j] .+ (dt * Aexpl[i, j]) .* R[j]
140139 # initial value: we only need to do this if using an iterative method:
141- U .+ = (dt * tab . Aexpl[i, j]) .* (cache . L[j] .+ cache . R[j])
140+ U .+ = (dt * Aexpl[i, j]) .* (L[j] .+ R[j])
142141 end
143142
144143 # W = I - dt * Aimpl[i,i] * L
145144 # currently only use SDIRK methods where
146145 # Aimpl[i,i] = i == 1 ? 0 : const
147146 # TODO : handle changing dt & Aimpl[i,i] coeffs
148- if ! (DiffEqBase. isconstant (cache . W))
149- cache . W. t = τ
147+ if ! (DiffEqBase. isconstant (W))
148+ W. t = τ
150149 W_updated = true
151150 end
152- γ = - dt * tab . Aimpl[i, i]
153- if cache . W. γ != γ
154- cache . W. γ = γ
151+ γ = - dt * Aimpl[i, i]
152+ if W. γ != γ
153+ W. γ = γ
155154 W_updated = true
156155 end
157- cache . linsolve! (U, cache . W, Uhat, W_updated)
156+ linsolve! (U, W, Uhat, W_updated)
158157
159- τ = t + tab . C[i] * dt
160- fL! (cache . L[i], U, p, τ)
161- # or use cache. L[i] .= (U .- Uhat) ./ (dt * tab. Aimpl[i,i]) ?
162- fR! (cache . R[i], U, p, τ)
158+ τ = t + C[i] * dt
159+ fL! (L[i], U, p, τ)
160+ # or use L[i] .= (U .- Uhat) ./ (dt * Aimpl[i,i]) ?
161+ fR! (R[i], U, p, τ)
163162 end
164163
165164 # compute next step
166165 for i in 1 : Nstages
167- u .+ = (dt * tab . B[i]) .* (cache . L[i] .+ cache . R[i])
166+ u .+ = (dt * B[i]) .* (L[i] .+ R[i])
168167 end
169168end
170169
171170# WIP
172171function step_u! (int, cache:: AdditiveRungeKuttaFullCache{Nstages} , f:: DiffEqBase.ODEFunction ) where {Nstages}
173- tab = cache. tableau
174- U = cache. U
175- Uhat = cache. R[end ] # can be used as work array, as only used in last stage
176-
177-
178- u = int. u
179- p = int. p
180- t = int. t
181- dt = int. dt
172+ (; C, Aimpl, Aexpl, B) = cache. tableau
173+ (; u, p, t, dt) = int
174+ (; U, R, L, W, linsolve) = cache
175+ Uhat = R[end ] # can be used as work array, as only used in last stage
182176
183177 fL! = f. jvp # linear part
184178 f! = f. f # remainder
185179
186180 # first stage is always explicit
187- τ = t + tab . C[1 ] * dt
181+ τ = t + C[1 ] * dt
188182
189- f! (cache . R[1 ], u, p, τ)
190- fL! (cache . L[1 ], u, p, τ)
183+ f! (R[1 ], u, p, τ)
184+ fL! (L[1 ], u, p, τ)
191185 for i in 2 : Nstages
192186 # solve for W * U = Uhat
193187 # set U to initial guess based on fully explicit
194188 # TODO : we don't need this for direct solves
195189 U .= Uhat .= u
196190 for j in 1 : (i - 1 )
197- Uhat .+ = (dt * tab . Aimpl[i, j]) .* cache . L[j] .+ (dt * tab . Aexpl[i, j]) .* (cache . R[j] .- cache . L[j])
191+ Uhat .+ = (dt * Aimpl[i, j]) .* L[j] .+ (dt * Aexpl[i, j]) .* (R[j] .- L[j])
198192 # initial value: we only need to do this if using an iterative method:
199- U .+ = (dt * tab . Aexpl[i, j]) .* cache . R[j]
193+ U .+ = (dt * Aexpl[i, j]) .* R[j]
200194 end
201195 #=
202196 # solve for W * U = Uhat
@@ -205,36 +199,36 @@ function step_u!(int, cache::AdditiveRungeKuttaFullCache{Nstages}, f::DiffEqBase
205199 V .= u
206200 Ω .= 0
207201 for j = 1:i-1
208- V .+= dt * tab. Aexpl[i,j] .* cache. R[j]
209- Ω .+= (tab. Aimpl[i,j]-tab. Aexpl[i,j])/tab. Aimpl[i,i] .* cache. U[j]
202+ V .+= dt * Aexpl[i,j] .* R[j]
203+ Ω .+= (Aimpl[i,j]-Aexpl[i,j])/Aimpl[i,i] .* U[j]
210204 end
211205 Vhat .= V .+ Ω
212206 =#
213207 # W = I - dt * Aimpl[i,i] * L
214208 # currently only use SDIRK methods where
215209 # Aimpl[i,i] = i == 1 ? 0 : const
216210 # TODO : handle changing dt & Aimpl[i,i] coeffs
217- if ! (DiffEqBase. isconstant (cache . W))
218- cache . W. t = τ
211+ if ! (DiffEqBase. isconstant (W))
212+ W. t = τ
219213 W_updated = true
220214 end
221- γ = - dt * tab . Aimpl[i, i]
222- if cache . W. γ != γ
223- cache . W. γ = γ
215+ γ = - dt * Aimpl[i, i]
216+ if W. γ != γ
217+ W. γ = γ
224218 W_updated = true
225219 end
226- cache . linsolve! (U, cache . W, Uhat, W_updated)
220+ linsolve! (U, W, Uhat, W_updated)
227221 # U = V .- Ω
228222
229- τ = t + tab . C[i] * dt
230- fL! (cache . L[i], U, p, τ)
231- # or use cache. L[i] .= (U .- Uhat) ./ (dt * tab. Aimpl[i,i]) ?
232- f! (cache . R[i], U, p, τ)
223+ τ = t + C[i] * dt
224+ fL! (L[i], U, p, τ)
225+ # or use L[i] .= (U .- Uhat) ./ (dt * Aimpl[i,i]) ?
226+ f! (R[i], U, p, τ)
233227 end
234228
235229 # compute next step
236230 for i in 1 : Nstages
237- u .+ = (dt * tab . B[i]) .* cache . R[i]
231+ u .+ = (dt * B[i]) .* R[i]
238232 end
239233end
240234
0 commit comments