Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
BVPInterface = "349b9a96-74d8-4082-8332-dd34ffd0a01f"

[extensions]
BoundaryValueDiffEqODEInterfaceExt = "ODEInterface"
BoundaryValueDiffEqBVPInterfaceExt = "BVPInterface"

[compat]
ADTypes = "1.11"
Expand All @@ -42,7 +42,7 @@ Hwloc = "3.3"
InteractiveUtils = "<0.0.1, 1"
JET = "0.9.12"
LinearAlgebra = "1.10"
ODEInterface = "0.5"
BVPInterface = "0.1.0"
OrdinaryDiffEq = "6.90.1"
Pkg = "1.10.0"
Random = "1.10"
Expand All @@ -61,7 +61,7 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NonlinearSolveFirstOrder = "5959db7a-ea39-4486-b5fe-2dd0bf03d60d"
ODEInterface = "54ca160b-1b9f-5127-a996-1867f4bc2a2c"
BVPInterface = "349b9a96-74d8-4082-8332-dd34ffd0a01f"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -71,4 +71,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "DiffEqDevTools", "Hwloc", "InteractiveUtils", "JET", "LinearSolve", "NonlinearSolveFirstOrder", "ODEInterface", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"]
test = ["Aqua", "DiffEqDevTools", "Hwloc", "InteractiveUtils", "JET", "LinearSolve", "NonlinearSolveFirstOrder", "BVPInterface", "OrdinaryDiffEq", "Pkg", "Random", "ReTestItems", "RecursiveArrayTools", "StaticArrays", "Test"]
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
module BoundaryValueDiffEqODEInterfaceExt
module BoundaryValueDiffEqBVPInterfaceExt

using BoundaryValueDiffEq: BVPM2, BVPSOL, COLNEW
using BoundaryValueDiffEq: BVPM2, BVPSOL, COLSYS, COLNEW
using BoundaryValueDiffEqCore: BoundaryValueDiffEqAlgorithm, __extract_u0,
__initial_guess_length, __extract_mesh,
__flatten_initial_guess, __get_bcresid_prototype,
__has_initial_guess, __initial_guess
using SciMLBase: SciMLBase, BVProblem, TwoPointBVProblem, ReturnCode
using ODEInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
using BVPInterface: OptionsODE, OPT_ATOL, OPT_RTOL, OPT_METHODCHOICE, OPT_DIAGNOSTICOUTPUT,
OPT_ERRORCONTROL, OPT_SINGULARTERM, OPT_MAXSTEPS, OPT_BVPCLASS,
OPT_SOLMETHOD, OPT_RHS_CALLMODE, OPT_COLLOCATIONPTS, OPT_ADDGRIDPOINTS,
OPT_MAXSUBINTERVALS, RHS_CALL_INSITU, evalSolution
using ODEInterface: Bvpm2, bvpm2_init, bvpm2_solve, bvpm2_destroy, bvpm2_get_x
using ODEInterface: bvpsol
using ODEInterface: colnew
using BVPInterface: Bvpm2, bvpm2_init, bvpm2_solve, bvpm2_destroy, bvpm2_get_x
using BVPInterface: bvpsol
using BVPInterface: colsys
using BVPInterface: colnew

using FastClosures: @closure
using ForwardDiff: ForwardDiff
Expand Down Expand Up @@ -339,6 +340,168 @@ function SciMLBase.__solve(prob::BVProblem, alg::COLNEW; maxiters = 1000,
stats = destats, original = (sol, retcode, stats))
end

#-------
# COLSYS
#-------
function SciMLBase.__solve(prob::BVProblem, alg::COLSYS; maxiters = 1000,
reltol = 1e-3, dt = 0.0, verbose = true, kwargs...)
dt ≤ 0 && throw(ArgumentError("`dt` must be positive"))

t₀, t₁ = prob.tspan
u0_ = __extract_u0(prob.u0, prob.p, t₀)
u0_size = size(u0_)
n = __initial_guess_length(prob.u0)

u0 = __flatten_initial_guess(prob.u0)
mesh = __extract_mesh(prob.u0, t₀, t₁, ifelse(n == -1, dt, n - 1))
if u0 === nothing
# initial_guess function was provided
u0 = mapreduce(@closure(t->vec(__initial_guess(prob.u0, prob.p, t))), hcat, mesh)
end

no_odes = length(u0_)

# has_initial_guess = prob.u0 isa AbstractVector{<:AbstractArray}
# dt ≤ 0 && throw(ArgumentError("dt must be positive"))
# no_odes, n, u0 = if has_initial_guess
# length(first(prob.u0)), (length(prob.u0) - 1), reduce(hcat, prob.u0)
# else
# length(prob.u0), Int(cld((prob.tspan[2] - prob.tspan[1]), dt)), prob.u0
# end

T = eltype(u0)
# mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
orders = ones(Int, no_odes)
_tspan = [prob.tspan[1], prob.tspan[2]]
iip = SciMLBase.isinplace(prob)

rhs = @closure (t, u, du) -> begin
if iip
prob.f(du, u, prob.p, t)
else
(du .= prob.f(u, prob.p, t))
end
end

if prob.f.jac === nothing
if iip
jac = (df, u, p, t) -> begin
_du = similar(u)
prob.f(_du, u, p, t)
_f = @closure (du, u) -> prob.f(du, u, p, t)
ForwardDiff.jacobian!(df, _f, _du, u)
return
end
else
jac = (df, u, p, t) -> begin
_du = prob.f(u, p, t)
_f = @closure (du, u) -> (du .= prob.f(u, p, t))
ForwardDiff.jacobian!(df, _f, _du, u)
return
end
end
else
jac = prob.f.jac
end
Drhs = @closure (t, u, df) -> jac(df, u, prob.p, t)

bcresid_prototype, _ = __get_bcresid_prototype(prob.problem_type, prob, u0)

if prob.problem_type isa TwoPointBVProblem
n_bc_a = length(first(bcresid_prototype))
n_bc_b = length(last(bcresid_prototype))
zeta = vcat(fill(first(prob.tspan), n_bc_a), fill(last(prob.tspan), n_bc_b))
bc = @closure (i, z, resid) -> begin
tmpa = copy(z)
tmpb = copy(z)
tmp_resid_a = zeros(T, n_bc_a)
tmp_resid_b = zeros(T, n_bc_b)
prob.f.bc[1](tmp_resid_a, tmpa, prob.p)
prob.f.bc[2](tmp_resid_b, tmpb, prob.p)

for j in 1:n_bc_a
if i == j
resid[1] = tmp_resid_a[j]
end
end
for j in 1:n_bc_b
if i == (j + n_bc_a)
resid[1] = tmp_resid_b[j]
end
end
end

Dbc = @closure (i, z, dbc) -> begin
for j in 1:n_bc_a
if i == j
dbc[i] = 1.0
end
end
for j in 1:n_bc_b
if i == (j + n_bc_a)
dbc[i] = 1.0
end
end
end
fixed_points = nothing
else
zeta = sort(alg.zeta)
bc = alg.bc_func
Dbc = alg.dbc_func
left_index = findlast(x -> x ≈ t₀, zeta) + 1
right_index = findfirst(x -> x ≈ t₁, zeta) - 1
fixed_points = alg.zeta[left_index:right_index]
end

if fixed_points === nothing
opt = OptionsODE(
OPT_BVPCLASS => alg.bvpclass, OPT_COLLOCATIONPTS => alg.collocationpts,
OPT_MAXSTEPS => maxiters, OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output,
OPT_MAXSUBINTERVALS => alg.max_num_subintervals, OPT_RTOL => reltol)
else
opt = OptionsODE(
OPT_BVPCLASS => alg.bvpclass, OPT_COLLOCATIONPTS => alg.collocationpts,
OPT_MAXSTEPS => maxiters, OPT_DIAGNOSTICOUTPUT => alg.diagnostic_output,
OPT_MAXSUBINTERVALS => alg.max_num_subintervals,
OPT_RTOL => reltol, OPT_ADDGRIDPOINTS => fixed_points)
end

# Provide initial guess could really help COLSYS to converge
guess = @closure (t, u, du) -> begin
if iip
u .= prob.u0
prob.f(du, u, prob.p, t)
else
u .= prob.u0
(du .= prob.f(u, prob.p, t))
end
end

sol, retcode, stats = colsys(_tspan, orders, zeta, rhs, Drhs, bc, Dbc, guess, opt)

if verbose
if retcode == 0
@warn "Collocation matrix is singular"
elseif retcode == -1
@warn "The expected no. of subintervals exceeds storage(try to increase \
`OPT_MAXSUBINTERVALS`)"
elseif retcode == -2
@warn "The nonlinear iteration has not converged"
elseif retcode == -3
@warn "There is an input data error"
end
end

evalsol = evalSolution(sol, mesh)
destats = SciMLBase.DEStats(
stats["no_rhs_calls"], 0, 0, 0, stats["no_jac_calls"], 0, 0, 0, 0, 0, 0, 0, 0)

return SciMLBase.build_solution(
prob, alg, mesh, collect(Vector{eltype(evalsol)}, eachrow(evalsol));
retcode = retcode > 0 ? ReturnCode.Success : ReturnCode.Failure,
stats = destats, original = (sol, retcode, stats))
end

export BVPM2, BVPSOL, COLNEW

end
2 changes: 1 addition & 1 deletion src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export MIRKN4, MIRKN6

export Ascher1, Ascher2, Ascher3, Ascher4, Ascher5, Ascher6, Ascher7

export BVPM2, BVPSOL, COLNEW # From ODEInterface.jl
export BVPM2, BVPSOL, COLSYS, COLNEW # From ODEInterface.jl

export MIRKJacobianComputationAlgorithm, BVPJacobianAlgorithm

Expand Down
96 changes: 84 additions & 12 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Algorithms from ODEInterface.jl
# Algorithms from BVPInterface.jl
"""
BVPM2(; max_num_subintervals = 3000, method_choice = 4, diagnostic_output = 1,
error_control = 1, singular_term = nothing)
Expand Down Expand Up @@ -33,7 +33,7 @@ Fortran code for solving two-point boundary value problems. For detailed documen

!!! note

Only available if the `ODEInterface` package is loaded.
Only available if the `BVPInterface` package is loaded.
"""
struct BVPM2{S} <: BoundaryValueDiffEqAlgorithm
max_num_subintervals::Int
Expand All @@ -44,8 +44,8 @@ struct BVPM2{S} <: BoundaryValueDiffEqAlgorithm

function BVPM2(max_num_subintervals::Int, method_choice::Int, diagnostic_output::Int,
error_control::Int, singular_term::Union{Nothing, AbstractMatrix})
if Base.get_extension(@__MODULE__, :BoundaryValueDiffEqODEInterfaceExt) === nothing
error("`BVPM2` requires `ODEInterface.jl` to be loaded")
if Base.get_extension(@__MODULE__, :BoundaryValueDiffEqBVPInterfaceExt) === nothing
error("`BVPM2` requires `BVPInterface.jl` to be loaded")
end
return new{typeof(singular_term)}(max_num_subintervals, method_choice,
diagnostic_output, error_control, singular_term)
Expand Down Expand Up @@ -85,16 +85,16 @@ For detailed documentation, see

!!! note

Only available if the `ODEInterface` package is loaded.
Only available if the `BVPInterface` package is loaded.
"""
struct BVPSOL{O} <: BoundaryValueDiffEqAlgorithm
bvpclass::Int
sol_method::Int
odesolver::O

function BVPSOL(bvpclass::Int, sol_method::Int, odesolver)
if Base.get_extension(@__MODULE__, :BoundaryValueDiffEqODEInterfaceExt) === nothing
error("`BVPSOL` requires `ODEInterface.jl` to be loaded")
if Base.get_extension(@__MODULE__, :BoundaryValueDiffEqBVPInterfaceExt) === nothing
error("`BVPSOL` requires `BVPInterface.jl` to be loaded")
end
return new{typeof(odesolver)}(bvpclass, sol_method, odesolver)
end
Expand All @@ -104,6 +104,78 @@ function BVPSOL(; bvpclass = 2, sol_method = 0, odesolver = nothing)
return BVPSOL(bvpclass, sol_method, odesolver)
end

"""
COLSYS(; bvpclass = 1, collocationpts = 7, diagnostic_output = 1,
max_num_subintervals = 3000, bc_func = nothing, dbc_func = nothing,
zeta = nothing)
COLSYS(bvpclass::Int, collocationpts::Int, diagnostic_output::Int,
max_num_subintervals::Int, bc_func, dbc_func, zeta::AbstractVector)

## Keyword Arguments:

- `bvpclass`: Boundary value problem classification, default as nonlinear and
"extra sensitive", available choices:

+ `0`: Linear boundary value problem.
+ `1`: Nonlinear and regular.
+ `2`: Nonlinear and "extra sensitive" (first relax factor is rstart and the
nonlinear iteration does not rely on past convergence).
+ `3`: fail-early: return immediately upon: a) two successive non-convergences.
b) after obtaining an error estimate for the first time.

- `collocationpts`: Number of collocation points per subinterval. Require
orders[i] ≤ k ≤ 7, default as 7
- `diagnostic_output`: Diagnostic output for COLSYS, default as no printout, available
choices:

+ `-1`: Full diagnostic printout.
+ `0`: Selected printout.
+ `1`: No printout.
- `max_num_subintervals`: Number of maximal subintervals, default as 3000.
- `bc_func`: Boundary condition accord with BVPInterface.jl, only used for multi-points BVP.
- `dbc_func`: Jacobian of boundary condition accord with BVPInterface.jl, only used for multi-points BVP.
- `zeta`: The points in interval where boundary conditions are specified, only used for multi-points BVP.

A Fortran77 code solves a multi-points boundary value problems for a mixed order system of
ODEs. It incorporates a new basis representation replacing b-splines, and improvements for
the linear and nonlinear algebraic equation solvers.

!!! warning

Only supports two-point boundary value problems.

!!! note

Only available if the `BVPInterface` package is loaded.
"""
struct COLSYS <: BoundaryValueDiffEqAlgorithm
bvpclass::Int
collocationpts::Int
diagnostic_output::Int
max_num_subintervals::Int
bc_func::Union{Function, Nothing}
dbc_func::Union{Function, Nothing}
zeta::Union{AbstractVector, Nothing}

function COLSYS(bvpclass::Int, collocationpts::Int, diagnostic_output::Int,
max_num_subintervals::Int, bc_func::Union{Function, Nothing},
dbc_func::Union{Function, Nothing}, zeta::Union{AbstractVector, Nothing})
if Base.get_extension(@__MODULE__, :BoundaryValueDiffEqBVPInterfaceExt) === nothing
error("`COLSYS` requires `BVPInterface.jl` to be loaded")
end
return new(bvpclass, collocationpts, diagnostic_output,
max_num_subintervals, bc_func, dbc_func, zeta)
end
end

function COLSYS(; bvpclass::Int = 1, collocationpts::Int = 7, diagnostic_output::Int = 1,
max_num_subintervals::Int = 4000, bc_func::Union{Function, Nothing} = nothing,
dbc_func::Union{Function, Nothing} = nothing,
zeta::Union{AbstractVector, Nothing} = nothing)
return COLSYS(bvpclass, collocationpts, diagnostic_output,
max_num_subintervals, bc_func, dbc_func, zeta)
end

"""
COLNEW(; bvpclass = 1, collocationpts = 7, diagnostic_output = 1,
max_num_subintervals = 3000, bc_func = nothing, dbc_func = nothing,
Expand Down Expand Up @@ -132,8 +204,8 @@ end
+ `0`: Selected printout.
+ `1`: No printout.
- `max_num_subintervals`: Number of maximal subintervals, default as 3000.
- `bc_func`: Boundary condition accord with ODEInterface.jl, only used for multi-points BVP.
- `dbc_func`: Jacobian of boundary condition accord with ODEInterface.jl, only used for multi-points BVP.
- `bc_func`: Boundary condition accord with BVPInterface.jl, only used for multi-points BVP.
- `dbc_func`: Jacobian of boundary condition accord with BVPInterface.jl, only used for multi-points BVP.
- `zeta`: The points in interval where boundary conditions are specified, only used for multi-points BVP.

A Fortran77 code solves a multi-points boundary value problems for a mixed order system of
Expand All @@ -146,7 +218,7 @@ the linear and nonlinear algebraic equation solvers.

!!! note

Only available if the `ODEInterface` package is loaded.
Only available if the `BVPInterface` package is loaded.
"""
struct COLNEW <: BoundaryValueDiffEqAlgorithm
bvpclass::Int
Expand All @@ -160,8 +232,8 @@ struct COLNEW <: BoundaryValueDiffEqAlgorithm
function COLNEW(bvpclass::Int, collocationpts::Int, diagnostic_output::Int,
max_num_subintervals::Int, bc_func::Union{Function, Nothing},
dbc_func::Union{Function, Nothing}, zeta::Union{AbstractVector, Nothing})
if Base.get_extension(@__MODULE__, :BoundaryValueDiffEqODEInterfaceExt) === nothing
error("`COLNEW` requires `ODEInterface.jl` to be loaded")
if Base.get_extension(@__MODULE__, :BoundaryValueDiffEqBVPInterfaceExt) === nothing
error("`COLNEW` requires `BVPInterface.jl` to be loaded")
end
return new(bvpclass, collocationpts, diagnostic_output,
max_num_subintervals, bc_func, dbc_func, zeta)
Expand Down
Loading
Loading