Skip to content

Commit 837de33

Browse files
authored
Merge pull request #331 from SciML/qqy/simplify_interp
2 parents bda144a + 7475f94 commit 837de33

File tree

3 files changed

+36
-47
lines changed

3 files changed

+36
-47
lines changed

lib/BoundaryValueDiffEqCore/src/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ end
175175

176176
function eval_bc_residual!(
177177
resid, ::StandardSecondOrderBVProblem, bc!::BC, sol, dsol, p, t) where {BC}
178-
M = length(sol[1])
179178
bc!(resid, dsol, sol, p, t)
180179
end
181180

lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Generate new mesh based on the defect or the global error.
2525
(; order, errors, mesh, mesh_dt) = cache
2626
(abstol, _, _), _ = __split_kwargs(; cache.kwargs...)
2727
N = length(mesh)
28+
n = N - 1
2829

2930
safety_factor = T(1.3)
3031
ρ = T(1.0)
@@ -41,13 +42,12 @@ Generate new mesh based on the defect or the global error.
4142
r₃ = r₂ / (N - 1)
4243

4344
n_predict = round(Int, (safety_factor * r₂) + 1)
44-
n = N - 1
4545
n_ = T(0.1) * n
4646
n_predict = ifelse(abs((n_predict - n)) < n_, round(Int, n + n_), n_predict)
4747

4848
if r₁ ρ * r₃
49-
Nsub_star = 2 * (N - 1)
50-
if Nsub_star > cache.alg.max_num_subintervals # Need to determine the too large threshold
49+
Nsub_star = 2 * n
50+
if Nsub_star > cache.alg.max_num_subintervals
5151
info = ReturnCode.Failure
5252
meshₒ = mesh
5353
mesh_dt₀ = mesh_dt
@@ -78,6 +78,7 @@ end
7878
(; order, errors, mesh, mesh_dt) = cache
7979
(abstol, _, _), _ = __split_kwargs(; cache.kwargs...)
8080
N = length(mesh)
81+
n = N - 1
8182

8283
safety_factor = T(1.3)
8384
ρ = T(2.0)
@@ -91,15 +92,14 @@ end
9192
ŝ .= (ŝ ./ abstol) .^ (T(1) / order)
9293
r₁ = maximum(ŝ)
9394
r₂ = sum(ŝ)
94-
r₃ = r₂ / (N - 1)
95+
r₃ = r₂ / n
9596

9697
n_predict = round(Int, (safety_factor * r₂) + 1)
97-
n = N - 1
9898
n_ = T(0.1) * n
9999
n_predict = ifelse(abs((n_predict - n)) < n_, round(Int, n + n_), n_predict)
100100

101101
if r₁ ρ * r₃
102-
Nsub_star = 2 * (N - 1)
102+
Nsub_star = 2 * n
103103
# Need to determine the too large threshold
104104
if Nsub_star > cache.alg.max_num_subintervals
105105
info = ReturnCode.Failure
@@ -132,6 +132,7 @@ end
132132
(; order, errors, mesh, mesh_dt) = cache
133133
(abstol, _, _), _ = __split_kwargs(; cache.kwargs...)
134134
N = length(mesh)
135+
n = N - 1
135136

136137
safety_factor = T(1.3)
137138
ρ = T(2.0)
@@ -145,15 +146,15 @@ end
145146
ŝ .= (ŝ ./ abstol) .^ (T(1) / (order + 1))
146147
r₁ = maximum(ŝ)
147148
r₂ = sum(ŝ)
148-
r₃ = r₂ / (N - 1)
149+
r₃ = r₂ / n
149150

150151
n_predict = round(Int, (safety_factor * r₂) + 1)
151152
n = N - 1
152153
n_ = T(0.1) * n
153154
n_predict = ifelse(abs((n_predict - n)) < n_, round(Int, n + n_), n_predict)
154155

155156
if r₁ ρ * r₃
156-
Nsub_star = 2 * (N - 1)
157+
Nsub_star = 2 * n
157158
# Need to determine the too large threshold
158159
if Nsub_star > cache.alg.max_num_subintervals
159160
info = ReturnCode.Failure
@@ -186,30 +187,30 @@ end
186187
(; order, errors, mesh, mesh_dt) = cache
187188
(abstol, _, _), _ = __split_kwargs(; cache.kwargs...)
188189
N = length(mesh)
190+
n = N - 1
189191

190192
safety_factor = T(1.3)
191193
ρ = T(2.0)
192194
Nsub_star = 0
193-
Nsub_star_ub = 4 * (N - 1)
195+
Nsub_star_ub = 4 * n
194196
Nsub_star_lb = N ÷ 2
195197

196198
info = ReturnCode.Success
197199

198-
ŝ₁ = [maximum(abs, d) for d in errors.u[1:(N - 1)]]
200+
ŝ₁ = [maximum(abs, d) for d in errors.u[1:n]]
199201
ŝ₂ = [maximum(abs, d) for d in errors.u[N:end]]
200202
= similar(ŝ₁)
201203
ŝ .= (ŝ₁ ./ abstol) .^ (T(1) / (order + 1)) + (ŝ₂ ./ abstol) .^ (T(1) / (order + 1))
202204
r₁ = maximum(ŝ)
203205
r₂ = sum(ŝ)
204-
r₃ = r₂ / (N - 1)
206+
r₃ = r₂ / n
205207

206208
n_predict = round(Int, (safety_factor * r₂) + 1)
207-
n = N - 1
208209
n_ = T(0.1) * n
209210
n_predict = ifelse(abs((n_predict - n)) < n_, round(Int, n + n_), n_predict)
210211

211212
if r₁ ρ * r₃
212-
Nsub_star = 2 * (N - 1)
213+
Nsub_star = 2 * n
213214
# Need to determine the too large threshold
214215
if Nsub_star > cache.alg.max_num_subintervals
215216
info = ReturnCode.Failure
@@ -627,7 +628,8 @@ end
627628
# Here we should not directly in-place change z in several steps
628629
# because in final step we actually need to use the original z(which is cache.y₀.u[i])
629630
# we use fᵢ₂_cache to avoid additional allocations.
630-
function sum_stages!(z::AbstractArray, cache::MIRKCache{iip, T, use_both, DiffCacheNeeded},
631+
@views function sum_stages!(
632+
z::AbstractArray, cache::MIRKCache{iip, T, use_both, DiffCacheNeeded},
631633
w, i::Int, dt = cache.mesh_dt[i]) where {iip, T, use_both}
632634
(; stage, k_discrete, k_interp, fᵢ₂_cache) = cache
633635
(; s_star) = cache.ITU
@@ -640,7 +642,7 @@ function sum_stages!(z::AbstractArray, cache::MIRKCache{iip, T, use_both, DiffCa
640642

641643
return nothing
642644
end
643-
function sum_stages!(
645+
@views function sum_stages!(
644646
z::AbstractArray, cache::MIRKCache{iip, T, use_both, NoDiffCacheNeeded},
645647
w, i::Int, dt = cache.mesh_dt[i]) where {iip, T, use_both}
646648
(; stage, k_discrete, k_interp, fᵢ₂_cache) = cache

lib/BoundaryValueDiffEqMIRK/src/interpolation.jl

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -67,23 +67,22 @@ end
6767
i = interval(mesh, t)
6868
dt = mesh_dt[i]
6969
τ = (t - mesh[i]) / dt
70-
w, w′ = interp_weights(τ, cache.alg)
71-
sum_stages!(z, id, cache, w, i)
70+
w, _ = interp_weights(τ, cache.alg)
71+
sum_stages!(z, id, cache, w, i, T)
7272
end
7373

7474
@inline function interpolant!(dz::AbstractArray, id::MIRKInterpolation,
7575
cache::MIRKCache, t, mesh, mesh_dt, T::Type{Val{1}})
7676
i = interval(mesh, t)
7777
dt = mesh_dt[i]
7878
τ = (t - mesh[i]) / dt
79-
w, w′ = interp_weights(τ, cache.alg)
80-
z = similar(dz)
81-
sum_stages!(z, dz, id, cache, w, w′, i)
79+
_, w′ = interp_weights(τ, cache.alg)
80+
sum_stages!(dz, id, cache, w′, i, T)
8281
end
8382

84-
function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
83+
@views function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
8584
cache::MIRKCache{iip, T, use_both, DiffCacheNeeded},
86-
w, i::Int) where {iip, T, use_both}
85+
w, i::Int, ::Type{Val{0}}) where {iip, T, use_both}
8786
(; stage, k_discrete, k_interp) = cache
8887
(; s_star) = cache.ITU
8988
dt = cache.mesh_dt[i]
@@ -93,11 +92,11 @@ function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
9392
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true)
9493
z .= z .* dt .+ id.u[i]
9594

96-
return z
95+
return nothing
9796
end
98-
function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
97+
@views function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
9998
cache::MIRKCache{iip, T, use_both, NoDiffCacheNeeded},
100-
w, i::Int) where {iip, T, use_both}
99+
w, i::Int, ::Type{Val{0}}) where {iip, T, use_both}
101100
(; stage, k_discrete, k_interp) = cache
102101
(; s_star) = cache.ITU
103102
dt = cache.mesh_dt[i]
@@ -107,44 +106,33 @@ function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
107106
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true)
108107
z .= z .* dt .+ id.u[i]
109108

110-
return z
109+
return nothing
111110
end
112111

113112
@views function sum_stages!(
114-
z, z′, id::MIRKInterpolation, cache::MIRKCache{iip, T, use_both, DiffCacheNeeded},
115-
w, w′, i::Int) where {iip, T, use_both}
113+
z′, id::MIRKInterpolation, cache::MIRKCache{iip, T, use_both, DiffCacheNeeded},
114+
w′, i::Int, ::Type{Val{1}}) where {iip, T, use_both}
116115
(; stage, k_discrete, k_interp) = cache
117116
(; s_star) = cache.ITU
118-
dt = cache.mesh_dt[i]
119-
z .= zero(z)
120-
__maybe_matmul!(z, k_discrete[i].du[:, 1:stage], w[1:stage])
121-
__maybe_matmul!(
122-
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true)
123117
z′ .= zero(z′)
124118
__maybe_matmul!(z′, k_discrete[i].du[:, 1:stage], w′[1:stage])
125119
__maybe_matmul!(
126120
z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true)
127-
z .= z .* dt[1] .+ id.u[i]
128121

129-
return z, z′
122+
return nothing
130123
end
131124
@views function sum_stages!(
132-
z, z′, id::MIRKInterpolation, cache::MIRKCache{iip, T, use_both, NoDiffCacheNeeded},
133-
w, w′, i::Int, dt = cache.mesh_dt[i]) where {iip, T, use_both}
125+
z′, id::MIRKInterpolation, cache::MIRKCache{iip, T, use_both, NoDiffCacheNeeded},
126+
w′, i::Int, ::Type{Val{1}}) where {iip, T, use_both}
134127
(; stage, k_discrete, k_interp) = cache
135128
(; s_star) = cache.ITU
136129

137-
z .= zero(z)
138-
__maybe_matmul!(z, k_discrete[i][:, 1:stage], w[1:stage])
139-
__maybe_matmul!(
140-
z, k_interp.u[i][:, 1:(s_star - stage)], w[(stage + 1):s_star], true, true)
141130
z′ .= zero(z′)
142131
__maybe_matmul!(z′, k_discrete[i][:, 1:stage], w′[1:stage])
143132
__maybe_matmul!(
144133
z′, k_interp.u[i][:, 1:(s_star - stage)], w′[(stage + 1):s_star], true, true)
145-
z .= z .* dt[1] .+ id.u[i]
146134

147-
return z, z′
135+
return nothing
148136
end
149137

150138
@inline __build_interpolation(cache::MIRKCache, u::AbstractVector) = MIRKInterpolation(
@@ -163,9 +151,9 @@ function (s::EvalSol{C})(tval::Number) where {C <: MIRKCache}
163151
dt = cache.mesh_dt[ii]
164152
τ = (tval - t[ii]) / dt
165153
w, _ = evalsol_interp_weights(τ, alg)
166-
K = __needs_diffcache(alg.jac_alg) ? k_discrete[ii].du[:, 1:stage] :
167-
k_discrete[ii][:, 1:stage]
168-
__maybe_matmul!(z, K, w[1:stage])
154+
K = __needs_diffcache(alg.jac_alg) ? @view(k_discrete[ii].du[:, 1:stage]) :
155+
@view(k_discrete[ii][:, 1:stage])
156+
__maybe_matmul!(z, K, @view(w[1:stage]))
169157
z .= z .* dt .+ u[ii]
170158
return z
171159
end

0 commit comments

Comments
 (0)