Skip to content

Commit 36c353f

Browse files
committed
Fix expanded FIRK with Enzyme
1 parent a5c9d8d commit 36c353f

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

lib/BoundaryValueDiffEqFIRK/src/adaptivity.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ end
212212
Form the quartic interpolation constraint matrix, see bvp5c paper.
213213
"""
214214
function s_constraints(M, h)
215-
t = vec(repeat([0.0, 1.0 * h, 0.5 * h, 0.0, 1.0 * h, 0.5 * h], 1, M))
215+
t = repeat([0.0, 1.0 * h, 0.5 * h, 0.0, 1.0 * h, 0.5 * h], M)
216216
A = zeros(6 * M, 6 * M)
217217
for i in 1:6
218218
row_start = (i - 1) * M + 1

lib/BoundaryValueDiffEqFIRK/src/interpolation.jl

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ end
133133
# Expanded FIRK
134134
function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheExpand}
135135
(; t, u, cache) = s
136-
(; f, alg, ITU, p) = cache
136+
(; f, alg, ITU, mesh_dt, p) = cache
137137
(; q_coeff) = ITU
138138
stage = alg_stage(alg)
139139
# Quick handle for the case where tval is at the boundary
@@ -161,18 +161,67 @@ function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheExpand}
161161
for jj in 1:stage
162162
K[:, jj] = u[ctr_y + jj]
163163
end
164-
h = t[j + 1] - t[j]
164+
h = mesh_dt(j)
165165
τ = tval - t[j]
166166

167-
z₁, z₁′ = eval_q(yᵢ, 0.5, h, q_coeff, K) # Evaluate q(x) at midpoints
168-
S_coeffs = get_S_coeffs(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
167+
M = size(K, 1)
168+
z₁, z₁′ = similar(yᵢ), similar(yᵢ₊₁)
169+
for i in 1:M
170+
ki = @view K[i, :]
171+
coeffs = get_q_coeffs_interp(q_coeff, ki, h)
172+
z₁[i] = yᵢ[i] + sum(coeffs[ii] ** h)^(ii) for ii in axes(coeffs, 1))
173+
z₁′[i] = sum(ii * coeffs[ii] ** h)^(ii - 1) for ii in axes(coeffs, 1))
174+
end
175+
176+
S_coeffs = get_S_coeffs_interp(h, yᵢ, yᵢ₊₁, z₁, dyᵢ, dyᵢ₊₁, z₁′)
169177

170178
z = similar(yᵢ)
171179

172180
S_interpolate!(z, τ, S_coeffs)
173181
return z
174182
end
175183

184+
function get_S_coeffs_interp(h, yᵢ, yᵢ₊₁, dyᵢ, dyᵢ₊₁, ymid, dymid)
185+
vals = vcat(yᵢ, yᵢ₊₁, dyᵢ, dyᵢ₊₁, ymid, dymid)
186+
M = length(yᵢ)
187+
A = s_constraints_interp(M, h)
188+
coeffs = reshape(A \ vals, 6, M)'
189+
return coeffs
190+
end
191+
192+
function get_q_coeffs_interp(A, ki, h)
193+
coeffs = A * ki
194+
for i in axes(coeffs, 1)
195+
coeffs[i] = coeffs[i] / (h^(i - 1))
196+
end
197+
return coeffs
198+
end
199+
200+
function s_constraints_interp(M, h)
201+
t = repeat([0.0, 1.0 * h, 0.5 * h, 0.0, 1.0 * h, 0.5 * h], M)
202+
A = zeros(6 * M, 6 * M)
203+
204+
for i in 1:6
205+
row_start = (i - 1) * M + 1
206+
for k in 0:(M - 1)
207+
for j in 1:6
208+
A[row_start + k, j + k * 6] = t[i + k * 6]^(j - 1)
209+
end
210+
end
211+
end
212+
for i in 4:6
213+
row_start = (i - 1) * M + 1
214+
for k in 0:(M - 1)
215+
for j in 1:6
216+
A[row_start + k, j + k * 6] = j == 1.0 ? 0.0 :
217+
(j - 1) * t[i + k * 6]^(j - 2)
218+
end
219+
end
220+
end
221+
222+
return A
223+
end
224+
176225
# Nested FIRK
177226
function (s::EvalSol{C})(tval::Number) where {C <: FIRKCacheNested}
178227
(; t, u, cache) = s

0 commit comments

Comments
 (0)