Skip to content

Commit 5ff4ced

Browse files
committed
Fix CI
1 parent 36c23c2 commit 5ff4ced

File tree

5 files changed

+18
-17
lines changed

5 files changed

+18
-17
lines changed

lib/BoundaryValueDiffEqFIRK/src/firk.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -719,7 +719,8 @@ end
719719
return nothing
720720
end
721721

722-
@views function __firk_loss!(resid, u, p, y::AbstractVector, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2},
722+
@views function __firk_loss!(
723+
resid, u, p, y::AbstractVector, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2},
723724
residual, mesh, cache, _, trait::DiffCacheNeeded) where {BC1, BC2}
724725
y_ = recursive_unflatten!(y, u)
725726
resids = [get_tmp(r, u) for r in residual]
@@ -752,8 +753,8 @@ end
752753
return vcat(resid_bc, mapreduce(vec, vcat, resid_co))
753754
end
754755

755-
@views function __firk_loss(u, p, y::AbstractVector, pt::TwoPointBVProblem, bc::Tuple{BC1, BC2},
756-
mesh, cache, _, trait) where {BC1, BC2}
756+
@views function __firk_loss(u, p, y::AbstractVector, pt::TwoPointBVProblem,
757+
bc::Tuple{BC1, BC2}, mesh, cache, _, trait) where {BC1, BC2}
757758
y_ = recursive_unflatten!(y, u)
758759
soly_ = VectorOfArray(y_)
759760
resid_bca, resid_bcb = eval_bc_residual(pt, bc, y_, p, mesh)

lib/BoundaryValueDiffEqMIRK/src/adaptivity.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ end
129129

130130
@views function mesh_selector!(
131131
cache::MIRKCache{iip, T}, controller::SequentialErrorControl) where {iip, T}
132-
(; order, errors, TU, mesh, mesh_dt) = cache
132+
(; order, errors, mesh, mesh_dt) = cache
133133
(abstol, _, _), _ = __split_kwargs(; cache.kwargs...)
134134
N = length(mesh)
135135

@@ -183,7 +183,7 @@ end
183183

184184
@views function mesh_selector!(
185185
cache::MIRKCache{iip, T}, controller::HybridErrorControl) where {iip, T}
186-
(; order, errors, TU, mesh, mesh_dt) = cache
186+
(; order, errors, mesh, mesh_dt) = cache
187187
(abstol, _, _), _ = __split_kwargs(; cache.kwargs...)
188188
N = length(mesh)
189189

lib/BoundaryValueDiffEqMIRK/src/collocation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ end
9191
@views function Φ(fᵢ_cache, k_discrete, f, TU::MIRKTableau, y, u, p,
9292
mesh, mesh_dt, stage::Int, ::NoDiffCacheNeeded)
9393
(; c, v, x, b) = TU
94-
residuals = [__similar(yᵢ) for yᵢ in y[1:(end - 1)]]
94+
residuals = [safe_similar(yᵢ) for yᵢ in y[1:(end - 1)]]
9595
tmp = similar(fᵢ_cache)
9696
T = eltype(u)
9797
for i in eachindex(k_discrete)

lib/BoundaryValueDiffEqMIRK/src/interpolation.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ end
8282
end
8383

8484
function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
85-
cache::MIRKCache{iip, T, use_both, DiffCacheNeeded}, w,
86-
i::Int) where {iip, T, use_both}
85+
cache::MIRKCache{iip, T, use_both, DiffCacheNeeded},
86+
w, i::Int) where {iip, T, use_both}
8787
(; stage, k_discrete, k_interp) = cache
8888
(; s_star) = cache.ITU
8989
dt = cache.mesh_dt[i]
@@ -96,10 +96,11 @@ function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
9696
return z
9797
end
9898
function sum_stages!(z::AbstractArray, id::MIRKInterpolation,
99-
cache::MIRKCache{iip, T, use_both, NoDiffCacheNeeded}, w,
100-
i::Int, dt = cache.mesh_dt[i]) where {iip, T, use_both}
99+
cache::MIRKCache{iip, T, use_both, NoDiffCacheNeeded},
100+
w, i::Int) where {iip, T, use_both}
101101
(; stage, k_discrete, k_interp) = cache
102102
(; s_star) = cache.ITU
103+
dt = cache.mesh_dt[i]
103104
z .= zero(z)
104105
__maybe_matmul!(z, k_discrete[i][:, 1:stage], w[1:stage])
105106
__maybe_matmul!(
@@ -159,7 +160,7 @@ function (s::EvalSol{C})(tval::Number) where {C <: MIRKCache}
159160
(tval == t[end]) && return last(u)
160161
z = zero(last(u))
161162
ii = interval(t, tval)
162-
dt = t[ii + 1] - t[ii]
163+
dt = cache.mesh_dt[ii]
163164
τ = (tval - t[ii]) / dt
164165
w, _ = evalsol_interp_weights(τ, alg)
165166
K = __needs_diffcache(alg.jac_alg) ? k_discrete[ii].du[:, 1:stage] :

lib/BoundaryValueDiffEqMIRK/src/mirk.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0, abstol =
103103
vecf, vecbc
104104
end
105105

106-
prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob
106+
#prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob
107107

108108
return MIRKCache{iip, T, use_both, typeof(diffcache)}(
109-
alg_order(alg), stage, N, size(X), f, bc, prob_, prob.problem_type,
109+
alg_order(alg), stage, N, size(X), f, bc, prob, prob.problem_type,
110110
prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete,
111111
k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, errors, new_stages,
112112
resid₁_size, (; abstol, dt, adaptive, controller, kwargs...))
@@ -156,11 +156,11 @@ function SciMLBase.solve!(cache::MIRKCache)
156156
end
157157

158158
function __perform_mirk_iteration(cache::MIRKCache, abstol, adaptive::Bool,
159-
controller; nlsolve_kwargs = (;), kwargs...)
159+
controller::AbstractErrorControl; nlsolve_kwargs = (;), kwargs...)
160160
nlprob = __construct_nlproblem(cache, vec(cache.y₀), copy(cache.y₀))
161161
nlsolve_alg = __concrete_nonlinearsolve_algorithm(nlprob, cache.alg.nlsolve)
162162
sol_nlprob = __solve(
163-
nlprob, nlsolve_alg; abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
163+
nlprob, nlsolve_alg; abstol = abstol, kwargs..., nlsolve_kwargs..., alias_u0 = true)
164164
recursive_unflatten!(cache.y₀, sol_nlprob.u)
165165

166166
error_norm = 2 * abstol
@@ -280,10 +280,9 @@ end
280280
residual, mesh, cache, _, trait::NoDiffCacheNeeded) where {BC1, BC2}
281281
y_ = recursive_unflatten!(y, u)
282282
Φ!(residual[2:end], cache, y_, u, trait)
283-
soly_ = VectorOfArray(y_)
284283
resida = residual[1][1:prod(cache.resid_size[1])]
285284
residb = residual[1][(prod(cache.resid_size[1]) + 1):end]
286-
eval_bc_residual!((resida, residb), pt, bc!, soly_, p, mesh)
285+
eval_bc_residual!((resida, residb), pt, bc!, y_, p, mesh)
287286
recursive_flatten_twopoint!(resid, residual, cache.resid_size)
288287
return nothing
289288
end

0 commit comments

Comments
 (0)