@@ -7,7 +7,7 @@ After we construct an interpolant, we use interp_eval to evaluate it.
77 i = interval (mesh, t)
88 dt = mesh_dt[i]
99 τ = (t - mesh[i]) / dt
10- w, w′ = interp_weights (τ, cache. alg)
10+ w, _ = interp_weights (τ, cache. alg)
1111 sum_stages! (y, cache, w, i, dt)
1212 return y
1313end
@@ -626,32 +626,35 @@ function sum_stages!(cache::MIRKCache{iip, T, use_both, NoDiffCacheNeeded}, w,
626626 sum_stages! (cache. fᵢ_cache, cache. fᵢ₂_cache, cache, w, w′, i, dt)
627627end
628628
629+ # Here we should not directly in-place change z in several steps
630+ # because in final step we actually need to use the original z(which is cache.y₀.u[i])
631+ # we use fᵢ₂_cache to avoid additional allocations.
629632function sum_stages! (z:: AbstractArray , cache:: MIRKCache{iip, T, use_both, DiffCacheNeeded} ,
630633 w, i:: Int , dt = cache. mesh_dt[i]) where {iip, T, use_both}
631- (; stage, k_discrete, k_interp) = cache
634+ (; stage, k_discrete, k_interp, fᵢ₂_cache ) = cache
632635 (; s_star) = cache. ITU
633636
634- z .= zero (z)
635- __maybe_matmul! (z , k_discrete[i]. du[:, 1 : stage], w[1 : stage])
637+ fᵢ₂_cache .= zero (z)
638+ __maybe_matmul! (fᵢ₂_cache , k_discrete[i]. du[:, 1 : stage], w[1 : stage])
636639 __maybe_matmul! (
637- z , k_interp. u[i][:, 1 : (s_star - stage)], w[(stage + 1 ): s_star], true , true )
638- z .= z .* dt .+ cache. y₀. u[i]
640+ fᵢ₂_cache , k_interp. u[i][:, 1 : (s_star - stage)], w[(stage + 1 ): s_star], true , true )
641+ z .= fᵢ₂_cache .* dt .+ cache. y₀. u[i]
639642
640- return z
643+ return nothing
641644end
642645function sum_stages! (
643646 z:: AbstractArray , cache:: MIRKCache{iip, T, use_both, NoDiffCacheNeeded} ,
644647 w, i:: Int , dt = cache. mesh_dt[i]) where {iip, T, use_both}
645- (; stage, k_discrete, k_interp) = cache
648+ (; stage, k_discrete, k_interp, fᵢ₂_cache ) = cache
646649 (; s_star) = cache. ITU
647650
648- z .= zero (z)
649- __maybe_matmul! (z , k_discrete[i][:, 1 : stage], w[1 : stage])
651+ fᵢ₂_cache .= zero (z)
652+ __maybe_matmul! (fᵢ₂_cache , k_discrete[i][:, 1 : stage], w[1 : stage])
650653 __maybe_matmul! (
651- z , k_interp. u[i][:, 1 : (s_star - stage)], w[(stage + 1 ): s_star], true , true )
652- z .= z .* dt .+ cache. y₀. u[i]
654+ fᵢ₂_cache , k_interp. u[i][:, 1 : (s_star - stage)], w[(stage + 1 ): s_star], true , true )
655+ z .= fᵢ₂_cache .* dt .+ cache. y₀. u[i]
653656
654- return z
657+ return nothing
655658end
656659
657660@views function sum_stages! (z, z′, cache:: MIRKCache{iip, T, use_both, DiffCacheNeeded} , w,
0 commit comments