Skip to content

Commit ae81958

Browse files
committed
Fix FIRK nested error
1 parent c1d4036 commit ae81958

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

lib/BoundaryValueDiffEqFIRK/src/firk.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,7 @@ function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0,
162162

163163
prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob
164164

165-
if isa(prob.u0, AbstractArray) && eltype(prob.u0) <: AbstractVector
166-
u0_mat = hcat(prob.u0...)
167-
avg_u0 = vec(sum(u0_mat, dims = 2)) / size(u0_mat, 2)
168-
else
169-
avg_u0 = prob.u0
170-
end
171-
172-
K0 = repeat(avg_u0, 1, stage) # Somewhat arbitrary initialization of K
165+
K0 = __K0_on_u0(prob.u0, stage) # Somewhat arbitrary initialization of K
173166

174167
nestprob_p = zeros(T, M + 2)
175168
nest_tol = alg.nest_tol

lib/BoundaryValueDiffEqFIRK/src/utils.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,20 @@ function BoundaryValueDiffEqCore.__append_similar!(
3939
append!(x, VectorOfArray([similar(last(x)) for _ in 1:N]))
4040
return x
4141
end
42+
43+
@inline __K0_on_u0(u0::AbstractArray, stage) = repeat(u0, 1, stage)
44+
@inline function __K0_on_u0(u0::AbstractVector{<:AbstractArray}, stage)
45+
u0_mat = hcat(u0...)
46+
avg_u0 = vec(sum(u0_mat, dims = 2)) / size(u0_mat, 2)
47+
return repeat(avg_u0, 1, stage)
48+
end
49+
@inline function __K0_on_u0(u0::AbstractVectorOfArray, stage)
50+
u0_mat = hcat(u0.u...)
51+
avg_u0 = vec(sum(u0_mat, dims = 2)) / size(u0_mat, 2)
52+
return repeat(avg_u0, 1, stage)
53+
end
54+
@inline function __K0_on_u0(u0::SciMLBase.ODESolution, stage)
55+
u0_mat = hcat(u0.u...)
56+
avg_u0 = vec(sum(u0_mat, dims = 2)) / size(u0_mat, 2)
57+
return repeat(avg_u0, 1, stage)
58+
end

0 commit comments

Comments
 (0)