Skip to content

Commit 7c3246c

Browse files
authored
Merge pull request #286 from ErikQQY/qqy/fix_dt
2 parents 6918da0 + f865b05 commit 7c3246c

File tree

5 files changed

+59
-13
lines changed

5 files changed

+59
-13
lines changed

lib/BoundaryValueDiffEqCore/src/utils.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,6 @@ end
207207
function __extract_problem_details(
208208
prob, u0::SciMLBase.ODESolution; dt = 0.0, check_positive_dt::Bool = false)
209209
# Problem passes in a initial guess function
210-
check_positive_dt && dt 0 && throw(ArgumentError("dt must be positive"))
211210
_u0 = first(u0.u)
212211
_t = u0.t
213212
return Val(true), eltype(_u0), length(_u0), (length(_t) - 1), _u0

lib/BoundaryValueDiffEqFIRK/src/firk.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,14 +162,7 @@ function init_nested(prob::BVProblem, alg::AbstractFIRK; dt = 0.0,
162162

163163
prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob
164164

165-
if isa(prob.u0, AbstractArray) && eltype(prob.u0) <: AbstractVector
166-
u0_mat = hcat(prob.u0...)
167-
avg_u0 = vec(sum(u0_mat, dims = 2)) / size(u0_mat, 2)
168-
else
169-
avg_u0 = prob.u0
170-
end
171-
172-
K0 = repeat(avg_u0, 1, stage) # Somewhat arbitrary initialization of K
165+
K0 = __K0_on_u0(prob.u0, stage) # Somewhat arbitrary initialization of K
173166

174167
nestprob_p = zeros(T, M + 2)
175168
nest_tol = alg.nest_tol

lib/BoundaryValueDiffEqFIRK/src/utils.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,20 @@ function BoundaryValueDiffEqCore.__append_similar!(
3939
append!(x, VectorOfArray([similar(last(x)) for _ in 1:N]))
4040
return x
4141
end
42+
43+
@inline __K0_on_u0(u0::AbstractArray, stage) = repeat(u0, 1, stage)
44+
@inline function __K0_on_u0(u0::AbstractVector{<:AbstractArray}, stage)
45+
u0_mat = hcat(u0...)
46+
avg_u0 = vec(sum(u0_mat, dims = 2)) / size(u0_mat, 2)
47+
return repeat(avg_u0, 1, stage)
48+
end
49+
@inline function __K0_on_u0(u0::AbstractVectorOfArray, stage)
50+
u0_mat = hcat(u0.u...)
51+
avg_u0 = vec(sum(u0_mat, dims = 2)) / size(u0_mat, 2)
52+
return repeat(avg_u0, 1, stage)
53+
end
54+
@inline function __K0_on_u0(u0::SciMLBase.ODESolution, stage)
55+
u0_mat = hcat(u0.u...)
56+
avg_u0 = vec(sum(u0_mat, dims = 2)) / size(u0_mat, 2)
57+
return repeat(avg_u0, 1, stage)
58+
end

lib/BoundaryValueDiffEqFIRK/test/nested/firk_basic_tests.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -437,11 +437,9 @@ end =#
437437

438438
bvp4 = TwoPointBVProblem(simplependulum!, (bc2a!, bc2b!), sol3, (0, pi / 2),
439439
pi / 2; bcresid_prototype = (zeros(1), zeros(1)))
440-
@test_broken SciMLBase.successful_retcode(solve(
441-
bvp4, RadauIIa5(; nested_nlsolve = true), dt = 0.05))
440+
SciMLBase.successful_retcode(solve(bvp4, RadauIIa5(; nested_nlsolve = true), dt = 0.05))
442441

443442
bvp5 = TwoPointBVProblem(simplependulum!, (bc2a!, bc2b!), DiffEqArray(sol3.u, sol3.t),
444443
(0, pi / 2), pi / 2; bcresid_prototype = (zeros(1), zeros(1)))
445-
@test_broken SciMLBase.successful_retcode(solve(
446-
bvp5, RadauIIa5(; nested_nlsolve = true), dt = 0.05))
444+
SciMLBase.successful_retcode(solve(bvp5, RadauIIa5(; nested_nlsolve = true), dt = 0.05))
447445
end

test/misc/initial_guess_tests.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
@testitem "Initial guess providing" begin
2+
using BoundaryValueDiffEq, RecursiveArrayTools
3+
tspan = (0.0, pi / 2)
4+
function simplependulum!(du, u, p, t)
5+
θ = u[1]
6+
= u[2]
7+
du[1] =
8+
du[2] = -9.81 * sin(θ)
9+
end
10+
function bc!(residual, u, p, t)
11+
residual[1] = u(pi / 4)[1] + pi / 2
12+
residual[2] = u(pi / 2)[1] - pi / 2
13+
end
14+
u0 = [pi / 2, pi / 2]
15+
prob = BVProblem(simplependulum!, bc!, u0, tspan)
16+
sol1 = solve(prob, MIRK4(), dt = 0.05)
17+
18+
# Solution
19+
prob1 = BVProblem(simplependulum!, bc!, sol1, tspan)
20+
sol2 = solve(prob1, MIRK4())
21+
@test SciMLBase.successful_retcode(sol2)
22+
23+
sol3 = solve(prob1, RadauIIa5())
24+
@test SciMLBase.successful_retcode(sol3)
25+
26+
sol4 = solve(prob1, LobattoIIIa4(nested_nlsolve = true))
27+
@test SciMLBase.successful_retcode(sol4)
28+
29+
# VectorOfArray
30+
prob2 = BVProblem(simplependulum!, bc!, VectorOfArray(sol1.u), tspan)
31+
sol2 = solve(prob2, MIRK4())
32+
@test SciMLBase.successful_retcode(sol2)
33+
34+
sol3 = solve(prob2, RadauIIa5())
35+
@test SciMLBase.successful_retcode(sol3)
36+
37+
sol4 = solve(prob2, LobattoIIIa4(nested_nlsolve = true))
38+
@test SciMLBase.successful_retcode(sol4)
39+
end

0 commit comments

Comments
 (0)