5454
5555 for j in idx
5656 z = similar (cache. fᵢ₂_cache)
57- interp_eval ! (z, id. cache, tvals[j], id. cache. mesh, id. cache. mesh_dt)
57+ interpolant ! (z, id. cache, tvals[j], id. cache. mesh, id. cache. mesh_dt, deriv )
5858 vals[j] = idxs != = nothing ? z[idxs] : z
5959 end
6060 return DiffEqArray (vals, tvals)
6868
6969 for j in idx
7070 z = similar (cache. fᵢ₂_cache)
71- interp_eval ! (z, id. cache, tvals[j], id. cache. mesh, id. cache. mesh_dt)
71+ interpolant ! (z, id. cache, tvals[j], id. cache. mesh, id. cache. mesh_dt, deriv )
7272 vals[j] = z
7373 end
7474end
7575
7676@inline function interpolation (tval:: Number , id:: FIRKNestedInterpolation , idxs,
7777 deriv:: D , p, continuity:: Symbol = :left ) where {D}
7878 z = similar (id. cache. fᵢ₂_cache)
79- interp_eval ! (z, id. cache, tval, id. cache. mesh, id. cache. mesh_dt)
79+ interpolant ! (z, id. cache, tval, id. cache. mesh, id. cache. mesh_dt, deriv )
8080 return idxs != = nothing ? z[idxs] : z
8181end
8282
83+ @inline function interpolant! (z:: AbstractArray , cache:: FIRKCacheNested{iip, T} ,
84+ t, mesh, mesh_dt, :: Type{Val{0}} ) where {iip, T}
85+ (; f, ITU, nest_prob, alg) = cache
86+ (; q_coeff) = ITU
87+
88+ j = interval (mesh, t)
89+ h = mesh_dt[j]
90+ lf = (length (cache. y₀) - 1 ) / (length (cache. y) - 1 )
91+ if lf > 1
92+ h *= lf
93+ end
94+ τ = (t - mesh[j])
95+
96+ nest_nlsolve_alg = __concrete_nonlinearsolve_algorithm (nest_prob, alg. nlsolve)
97+ nestprob_p = zeros (T, cache. M + 2 )
98+
99+ yᵢ = copy (cache. y[j]. du)
100+ yᵢ₊₁ = copy (cache. y[j + 1 ]. du)
101+
102+ if iip
103+ dyᵢ = similar (yᵢ)
104+ dyᵢ₊₁ = similar (yᵢ₊₁)
105+
106+ f (dyᵢ, yᵢ, cache. p, mesh[j])
107+ f (dyᵢ₊₁, yᵢ₊₁, cache. p, mesh[j + 1 ])
108+ else
109+ dyᵢ = f (yᵢ, cache. p, mesh[j])
110+ dyᵢ₊₁ = f (yᵢ₊₁, cache. p, mesh[j + 1 ])
111+ end
112+
113+ nestprob_p[1 ] = mesh[j]
114+ nestprob_p[2 ] = mesh_dt[j]
115+ nestprob_p[3 : end ] .= yᵢ
116+
117+ _nestprob = remake (nest_prob, p = nestprob_p)
118+ nestsol = __solve (_nestprob, nest_nlsolve_alg; alg. nested_nlsolve_kwargs... )
119+ K = nestsol. u
120+
121+ z₁, z₁′ = eval_q (yᵢ, 0.5 , h, q_coeff, K) # Evaluate q(x) at midpoints
122+ S_coeffs = get_S_coeffs (h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
123+
124+ S_interpolate! (z, τ, S_coeffs)
125+ end
126+
127+ @inline function interpolant! (dz:: AbstractArray , cache:: FIRKCacheNested{iip, T} ,
128+ t, mesh, mesh_dt, :: Type{Val{1}} ) where {iip, T}
129+ (; f, ITU, nest_prob, alg) = cache
130+ (; q_coeff) = ITU
131+
132+ j = interval (mesh, t)
133+ h = mesh_dt[j]
134+ lf = (length (cache. y₀) - 1 ) / (length (cache. y) - 1 )
135+ if lf > 1
136+ h *= lf
137+ end
138+ τ = (t - mesh[j])
139+
140+ nest_nlsolve_alg = __concrete_nonlinearsolve_algorithm (nest_prob, alg. nlsolve)
141+ nestprob_p = zeros (T, cache. M + 2 )
142+
143+ yᵢ = copy (cache. y[j]. du)
144+ yᵢ₊₁ = copy (cache. y[j + 1 ]. du)
145+
146+ if iip
147+ dyᵢ = similar (yᵢ)
148+ dyᵢ₊₁ = similar (yᵢ₊₁)
149+
150+ f (dyᵢ, yᵢ, cache. p, mesh[j])
151+ f (dyᵢ₊₁, yᵢ₊₁, cache. p, mesh[j + 1 ])
152+ else
153+ dyᵢ = f (yᵢ, cache. p, mesh[j])
154+ dyᵢ₊₁ = f (yᵢ₊₁, cache. p, mesh[j + 1 ])
155+ end
156+
157+ nestprob_p[1 ] = mesh[j]
158+ nestprob_p[2 ] = mesh_dt[j]
159+ nestprob_p[3 : end ] .= yᵢ
160+
161+ _nestprob = remake (nest_prob, p = nestprob_p)
162+ nestsol = __solve (_nestprob, nest_nlsolve_alg; alg. nested_nlsolve_kwargs... )
163+ K = nestsol. u
164+
165+ z₁, z₁′ = eval_q (yᵢ, 0.5 , h, q_coeff, K)
166+ S_coeffs = get_S_coeffs (h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
167+
168+ dS_interpolate! (dz, τ, S_coeffs)
169+ end
170+
83171# # Expanded
84172@inline function interpolation (tvals, id:: FIRKExpandInterpolation , idxs,
85173 deriv:: D , p, continuity:: Symbol = :left ) where {D}
97185
98186 for j in idx
99187 z = similar (cache. fᵢ₂_cache)
100- interp_eval ! (z, id. cache, tvals[j], id. cache. mesh, id. cache. mesh_dt)
188+ interpolant ! (z, id. cache, tvals[j], id. cache. mesh, id. cache. mesh_dt, deriv )
101189 vals[j] = idxs != = nothing ? z[idxs] : z
102190 end
103191 return DiffEqArray (vals, tvals)
@@ -111,18 +199,102 @@ end
111199
112200 for j in idx
113201 z = similar (cache. fᵢ₂_cache)
114- interp_eval ! (z, id. cache, tvals[j], id. cache. mesh, id. cache. mesh_dt)
202+ interpolant ! (z, id. cache, tvals[j], id. cache. mesh, id. cache. mesh_dt, deriv )
115203 vals[j] = z
116204 end
117205end
118206
119207@inline function interpolation (tval:: Number , id:: FIRKExpandInterpolation , idxs,
120208 deriv:: D , p, continuity:: Symbol = :left ) where {D}
121209 z = similar (id. cache. fᵢ₂_cache)
122- interp_eval ! (z, id. cache, tval, id. cache. mesh, id. cache. mesh_dt)
210+ interpolant ! (z, id. cache, tval, id. cache. mesh, id. cache. mesh_dt, deriv )
123211 return idxs != = nothing ? z[idxs] : z
124212end
125213
214+ @inline function interpolant! (z:: AbstractArray , cache:: FIRKCacheExpand{iip} ,
215+ t, mesh, mesh_dt, :: Type{Val{0}} ) where {iip}
216+ j = interval (mesh, t)
217+ h = mesh_dt[j]
218+ lf = (length (cache. y₀) - 1 ) / (length (cache. y) - 1 )
219+ if lf > 1
220+ h *= lf
221+ end
222+ τ = (t - mesh[j])
223+
224+ (; f, M, stage, p, ITU) = cache
225+ (; q_coeff) = ITU
226+
227+ K = safe_similar (cache. y[1 ]. du, M, stage)
228+
229+ ctr_y = (j - 1 ) * (stage + 1 ) + 1
230+
231+ yᵢ = cache. y[ctr_y]. du
232+ yᵢ₊₁ = cache. y[ctr_y + stage + 1 ]. du
233+
234+ if iip
235+ dyᵢ = similar (yᵢ)
236+ dyᵢ₊₁ = similar (yᵢ₊₁)
237+
238+ f (dyᵢ, yᵢ, p, mesh[j])
239+ f (dyᵢ₊₁, yᵢ₊₁, p, mesh[j + 1 ])
240+ else
241+ dyᵢ = f (yᵢ, p, mesh[j])
242+ dyᵢ₊₁ = f (yᵢ₊₁, p, mesh[j + 1 ])
243+ end
244+
245+ # Load interpolation residual
246+ for jj in 1 : stage
247+ K[:, jj] = cache. y[ctr_y + jj]. du
248+ end
249+
250+ z₁, z₁′ = eval_q (yᵢ, 0.5 , h, q_coeff, K) # Evaluate q(x) at midpoints
251+ S_coeffs = get_S_coeffs (h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
252+
253+ S_interpolate! (z, τ, S_coeffs)
254+ end
255+
256+ @inline function interpolant! (dz:: AbstractArray , cache:: FIRKCacheExpand{iip} ,
257+ t, mesh, mesh_dt, :: Type{Val{1}} ) where {iip}
258+ j = interval (mesh, t)
259+ h = mesh_dt[j]
260+ lf = (length (cache. y₀) - 1 ) / (length (cache. y) - 1 )
261+ if lf > 1
262+ h *= lf
263+ end
264+ τ = (t - mesh[j])
265+
266+ (; f, M, stage, p, ITU) = cache
267+ (; q_coeff) = ITU
268+
269+ K = safe_similar (cache. y[1 ]. du, M, stage)
270+
271+ ctr_y = (j - 1 ) * (stage + 1 ) + 1
272+
273+ yᵢ = cache. y[ctr_y]. du
274+ yᵢ₊₁ = cache. y[ctr_y + stage + 1 ]. du
275+
276+ if iip
277+ dyᵢ = similar (yᵢ)
278+ dyᵢ₊₁ = similar (yᵢ₊₁)
279+
280+ f (dyᵢ, yᵢ, p, mesh[j])
281+ f (dyᵢ₊₁, yᵢ₊₁, p, mesh[j + 1 ])
282+ else
283+ dyᵢ = f (yᵢ, p, mesh[j])
284+ dyᵢ₊₁ = f (yᵢ₊₁, p, mesh[j + 1 ])
285+ end
286+
287+ # Load interpolation residual
288+ for jj in 1 : stage
289+ K[:, jj] = cache. y[ctr_y + jj]. du
290+ end
291+
292+ z₁, z₁′ = eval_q (yᵢ, 0.5 , h, q_coeff, K) # Evaluate q(x) at midpoints
293+ S_coeffs = get_S_coeffs (h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
294+
295+ dS_interpolate! (dz, τ, S_coeffs)
296+ end
297+
126298@inline __build_interpolation (cache:: FIRKCacheExpand , u:: AbstractVector ) = FIRKExpandInterpolation (
127299 cache. mesh, u, cache)
128300@inline __build_interpolation (cache:: FIRKCacheNested , u:: AbstractVector ) = FIRKNestedInterpolation (
0 commit comments