6767 i = interval (mesh, t)
6868 dt = mesh_dt[i]
6969 τ = (t - mesh[i]) / dt
70- w, w′ = interp_weights (τ, cache. alg)
71- sum_stages! (z, id, cache, w, i)
70+ w, _ = interp_weights (τ, cache. alg)
71+ sum_stages! (z, id, cache, w, i, T )
7272end
7373
7474@inline function interpolant! (dz:: AbstractArray , id:: MIRKInterpolation ,
7575 cache:: MIRKCache , t, mesh, mesh_dt, T:: Type{Val{1}} )
7676 i = interval (mesh, t)
7777 dt = mesh_dt[i]
7878 τ = (t - mesh[i]) / dt
79- w, w′ = interp_weights (τ, cache. alg)
80- z = similar (dz)
81- sum_stages! (z, dz, id, cache, w, w′, i)
79+ _, w′ = interp_weights (τ, cache. alg)
80+ sum_stages! (dz, id, cache, w′, i, T)
8281end
8382
84- function sum_stages! (z:: AbstractArray , id:: MIRKInterpolation ,
83+ @views function sum_stages! (z:: AbstractArray , id:: MIRKInterpolation ,
8584 cache:: MIRKCache{iip, T, use_both, DiffCacheNeeded} ,
86- w, i:: Int ) where {iip, T, use_both}
85+ w, i:: Int , :: Type{Val{0}} ) where {iip, T, use_both}
8786 (; stage, k_discrete, k_interp) = cache
8887 (; s_star) = cache. ITU
8988 dt = cache. mesh_dt[i]
@@ -93,11 +92,11 @@ function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
9392 z, k_interp. u[i][:, 1 : (s_star - stage)], w[(stage + 1 ): s_star], true , true )
9493 z .= z .* dt .+ id. u[i]
9594
96- return z
95+ return nothing
9796end
98- function sum_stages! (z:: AbstractArray , id:: MIRKInterpolation ,
97+ @views function sum_stages! (z:: AbstractArray , id:: MIRKInterpolation ,
9998 cache:: MIRKCache{iip, T, use_both, NoDiffCacheNeeded} ,
100- w, i:: Int ) where {iip, T, use_both}
99+ w, i:: Int , :: Type{Val{0}} ) where {iip, T, use_both}
101100 (; stage, k_discrete, k_interp) = cache
102101 (; s_star) = cache. ITU
103102 dt = cache. mesh_dt[i]
@@ -107,44 +106,33 @@ function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
107106 z, k_interp. u[i][:, 1 : (s_star - stage)], w[(stage + 1 ): s_star], true , true )
108107 z .= z .* dt .+ id. u[i]
109108
110- return z
109+ return nothing
111110end
112111
113112@views function sum_stages! (
114- z, z ′, id:: MIRKInterpolation , cache:: MIRKCache{iip, T, use_both, DiffCacheNeeded} ,
115- w, w ′, i:: Int ) where {iip, T, use_both}
113+ z′, id:: MIRKInterpolation , cache:: MIRKCache{iip, T, use_both, DiffCacheNeeded} ,
114+ w′, i:: Int , :: Type{Val{1}} ) where {iip, T, use_both}
116115 (; stage, k_discrete, k_interp) = cache
117116 (; s_star) = cache. ITU
118- dt = cache. mesh_dt[i]
119- z .= zero (z)
120- __maybe_matmul! (z, k_discrete[i]. du[:, 1 : stage], w[1 : stage])
121- __maybe_matmul! (
122- z, k_interp. u[i][:, 1 : (s_star - stage)], w[(stage + 1 ): s_star], true , true )
123117 z′ .= zero (z′)
124118 __maybe_matmul! (z′, k_discrete[i]. du[:, 1 : stage], w′[1 : stage])
125119 __maybe_matmul! (
126120 z′, k_interp. u[i][:, 1 : (s_star - stage)], w′[(stage + 1 ): s_star], true , true )
127- z .= z .* dt[1 ] .+ id. u[i]
128121
129- return z, z′
122+ return nothing
130123end
131124@views function sum_stages! (
132- z, z ′, id:: MIRKInterpolation , cache:: MIRKCache{iip, T, use_both, NoDiffCacheNeeded} ,
133- w, w ′, i:: Int , dt = cache . mesh_dt[i] ) where {iip, T, use_both}
125+ z′, id:: MIRKInterpolation , cache:: MIRKCache{iip, T, use_both, NoDiffCacheNeeded} ,
126+ w′, i:: Int , :: Type{Val{1}} ) where {iip, T, use_both}
134127 (; stage, k_discrete, k_interp) = cache
135128 (; s_star) = cache. ITU
136129
137- z .= zero (z)
138- __maybe_matmul! (z, k_discrete[i][:, 1 : stage], w[1 : stage])
139- __maybe_matmul! (
140- z, k_interp. u[i][:, 1 : (s_star - stage)], w[(stage + 1 ): s_star], true , true )
141130 z′ .= zero (z′)
142131 __maybe_matmul! (z′, k_discrete[i][:, 1 : stage], w′[1 : stage])
143132 __maybe_matmul! (
144133 z′, k_interp. u[i][:, 1 : (s_star - stage)], w′[(stage + 1 ): s_star], true , true )
145- z .= z .* dt[1 ] .+ id. u[i]
146134
147- return z, z′
135+ return nothing
148136end
149137
150138@inline __build_interpolation (cache:: MIRKCache , u:: AbstractVector ) = MIRKInterpolation (
@@ -163,9 +151,9 @@ function (s::EvalSol{C})(tval::Number) where {C <: MIRKCache}
163151 dt = cache. mesh_dt[ii]
164152 τ = (tval - t[ii]) / dt
165153 w, _ = evalsol_interp_weights (τ, alg)
166- K = __needs_diffcache (alg. jac_alg) ? k_discrete[ii]. du[:, 1 : stage] :
167- k_discrete[ii][:, 1 : stage]
168- __maybe_matmul! (z, K, w[1 : stage])
154+ K = __needs_diffcache (alg. jac_alg) ? @view ( k_discrete[ii]. du[:, 1 : stage]) :
155+ @view ( k_discrete[ii][:, 1 : stage])
156+ __maybe_matmul! (z, K, @view ( w[1 : stage]) )
169157 z .= z .* dt .+ u[ii]
170158 return z
171159end
0 commit comments