Skip to content

Commit 6526a57

Browse files
committed
Support for u0 initial guess function
1 parent ed4234d commit 6526a57

File tree

5 files changed

+81
-23
lines changed

5 files changed

+81
-23
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1313
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1414
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
1515
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
16+
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
1617
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
1718
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1819
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
@@ -22,6 +23,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2223
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2324
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2425
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
26+
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
2527
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
2628
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
2729

@@ -56,6 +58,7 @@ SciMLBase = "2.5"
5658
Setfield = "1"
5759
SparseArrays = "1.9"
5860
SparseDiffTools = "2.9"
61+
Tricks = "0.1"
5962
TruncatedStacktraces = "1"
6063
UnPack = "1"
6164
julia = "1.9"

src/BoundaryValueDiffEq.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidat
55
@recompile_invalidations begin
66
using ADTypes, Adapt, BandedMatrices, DiffEqBase, ForwardDiff, LinearAlgebra,
77
NonlinearSolve, PreallocationTools, Preferences, RecursiveArrayTools, Reexport,
8-
SciMLBase, Setfield, SparseArrays, SparseDiffTools
8+
SciMLBase, Setfield, SparseArrays, SparseDiffTools, Tricks
99

1010
import ADTypes: AbstractADType
1111
import ArrayInterface: matrix_colors,
@@ -22,6 +22,31 @@ end
2222

2323
@reexport using ADTypes, DiffEqBase, NonlinearSolve, SparseDiffTools, SciMLBase
2424

25+
# TODO: Upstream
26+
# For BVPs we want to propagate even a function u0
27+
function DiffEqBase.get_concrete_u0(prob::BVProblem, isadapt, t0, kwargs)
28+
if haskey(kwargs, :u0)
29+
u0 = kwargs[:u0]
30+
else
31+
u0 = prob.u0
32+
end
33+
34+
isadapt && eltype(u0) <: Integer && (u0 = float.(u0))
35+
36+
_u0 = DiffEqBase.handle_distribution_u0(u0)
37+
38+
if isinplace(prob) && (_u0 isa Number || _u0 isa DiffEqBase.SArray)
39+
throw(DiffEqBase.IncompatibleInitialConditionError())
40+
end
41+
42+
if _u0 isa Tuple
43+
throw(DiffEqBase.TupleStateError())
44+
end
45+
46+
return _u0
47+
end
48+
# End of Upstream
49+
2550
include("types.jl")
2651
include("utils.jl")
2752
include("algorithms.jl")

src/solve/mirk.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,20 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
3535
abstol = 1e-3, adaptive = true, kwargs...)
3636
@set! alg.jac_alg = concrete_jacobian_algorithm(alg.jac_alg, prob, alg)
3737
iip = isinplace(prob)
38+
3839
_, T, M, n, X = __extract_problem_details(prob; dt, check_positive_dt = true)
40+
# NOTE: Assumes the user provided initial guess is on a uniform mesh
41+
mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
42+
43+
mesh_dt = diff(mesh)
44+
3945
chunksize = pickchunksize(M * (n + 1))
4046

4147
__alloc = x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg)
4248

4349
fᵢ_cache = __alloc(similar(X))
4450
fᵢ₂_cache = vec(similar(X))
4551

46-
# NOTE: Assumes the user provided initial guess is on a uniform mesh
47-
mesh = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
48-
mesh_dt = diff(mesh)
49-
5052
defect_threshold = T(0.1) # TODO: Allow user to specify these
5153
MxNsub = 3000 # TODO: Allow user to specify these
5254

@@ -100,7 +102,9 @@ function SciMLBase.__init(prob::BVProblem, alg::AbstractMIRK; dt = 0.0,
100102
vecf, vecbc
101103
end
102104

103-
return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob,
105+
prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob
106+
107+
return MIRKCache{iip, T}(alg_order(alg), stage, M, size(X), f, bc, prob_,
104108
prob.problem_type, prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt,
105109
k_discrete, k_interp, y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, new_stages,
106110
resid₁_size, (; defect_threshold, MxNsub, abstol, dt, adaptive, kwargs...))

src/solve/multiple_shooting.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ function __solve(prob::BVProblem, _alg::MultipleShooting; odesolve_kwargs = (;),
4949
ode_cache_loss_fn; kwargs..., verbose, odesolve_kwargs...)
5050
else
5151
u_at_nodes = __multiple_shooting_initialize!(nodes, u_at_nodes, prob, alg,
52-
cur_nshoot, all_nshoots[i - 1], ig, ode_cache_loss_fn; kwargs..., verbose,
53-
odesolve_kwargs...)
52+
cur_nshoot, all_nshoots[i - 1], ig, ode_cache_loss_fn, u0; kwargs...,
53+
verbose, odesolve_kwargs...)
5454
end
5555

5656
if prob.problem_type isa TwoPointBVProblem
@@ -362,9 +362,13 @@ end
362362
resize!(nodes, nshoots + 1)
363363
nodes .= range(tspan[1], tspan[2]; length = nshoots + 1)
364364

365-
N = length(first(u0))
366-
u_at_nodes = similar(first(u0), (nshoots + 1) * N)
367-
recursive_flatten!(u_at_nodes, u0)
365+
# NOTE: We don't check `u0 isa Function` since `u0` in-principle can be a callable
366+
# struct
367+
u0_ = u0 isa AbstractArray ? u0 : [__initial_guess(u0, prob.p, t) for t in nodes]
368+
369+
N = length(first(u0_))
370+
u_at_nodes = similar(first(u0_), (nshoots + 1) * N)
371+
recursive_flatten!(u_at_nodes, u0_)
368372

369373
return u_at_nodes
370374
end
@@ -401,7 +405,8 @@ end
401405
end
402406
else
403407
@warn "Initialization using odesolve failed. Initializing using 0s. It is \
404-
recommended to provide an `initial_guess` in this case."
408+
recommended to provide an initial guess function via \
409+
`u0 = <function>(p, t)` or `u0 = <function>(t)` in this case."
405410
fill!(u_at_nodes, 0)
406411
end
407412

@@ -410,16 +415,16 @@ end
410415

411416
# Grid coarsening
412417
@views function __multiple_shooting_initialize!(nodes, u_at_nodes_prev, prob, alg,
413-
nshoots, old_nshoots, ig, odecache_; kwargs...)
414-
@unpack f, u0, tspan, p = prob
418+
nshoots, old_nshoots, ig, odecache_, u0; kwargs...)
419+
@unpack f, tspan, p = prob
415420
prev_nodes = copy(nodes)
416421
odecache = odecache_ isa Vector ? first(odecache_) : odecache_
417422

418423
resize!(nodes, nshoots + 1)
419424
nodes .= range(tspan[1], tspan[2]; length = nshoots + 1)
420-
N = _unwrap_val(ig) ? length(first(u0)) : length(u0)
425+
N = length(u0)
421426

422-
u_at_nodes = similar(_unwrap_val(ig) ? first(u0) : u0, N + nshoots * N)
427+
u_at_nodes = similar(u0, N + nshoots * N)
423428
u_at_nodes[1:N] .= u_at_nodes_prev[1:N]
424429
u_at_nodes[(end - N + 1):end] .= u_at_nodes_prev[(end - N + 1):end]
425430

src/utils.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -139,17 +139,38 @@ function __extract_problem_details(prob, u0::AbstractArray; dt = 0.0,
139139
t₀, t₁ = prob.tspan
140140
return Val(false), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), prob.u0
141141
end
142-
function __extract_problem_details(prob, ::F; kwargs...) where {F <: Function}
143-
throw(ArgumentError("passing `u0` as a function is not supported yet. Curently we only \
144-
support AbstractArray or Vector of AbstractArrays as input! \
145-
Use the latter format for passing in initial guess!"))
142+
function __extract_problem_details(prob, f::F; dt = 0.0,
143+
check_positive_dt::Bool = false) where {F <: Function}
144+
# Problem passes in a initial guess function
145+
check_positive_dt && dt 0 && throw(ArgumentError("dt must be positive"))
146+
u0 = __initial_guess(f, prob.p, prob.tspan[1])
147+
t₀, t₁ = prob.tspan
148+
return Val(true), eltype(u0), length(u0), Int(cld(t₁ - t₀, dt)), u0
146149
end
147150

148-
__initial_state_from_prob(prob::BVProblem, mesh) = __initial_state_from_prob(prob.u0, mesh)
149-
__initial_state_from_prob(u0::AbstractArray, mesh) = [copy(vec(u0)) for _ in mesh]
150-
function __initial_state_from_prob(u0::AbstractVector{<:AbstractVector}, _)
151+
function __initial_guess(f::F, p::P, t::T) where {F, P, T}
152+
if static_hasmethod(f, Tuple{P, T})
153+
return f(p, t)
154+
elseif static_hasmethod(f, Tuple{T})
155+
return f(t)
156+
else
157+
throw(ArgumentError("`initial_guess` must be a function of the form `f(p, t)` or \
158+
`f(t)`"))
159+
end
160+
end
161+
162+
function __initial_state_from_prob(prob::BVProblem, mesh)
163+
return __initial_state_from_prob(prob, prob.u0, mesh)
164+
end
165+
function __initial_state_from_prob(::BVProblem, u0::AbstractArray, mesh)
166+
return [copy(vec(u0)) for _ in mesh]
167+
end
168+
function __initial_state_from_prob(::BVProblem, u0::AbstractVector{<:AbstractVector}, _)
151169
return [copy(vec(u)) for u in u0]
152170
end
171+
function __initial_state_from_prob(prob::BVProblem, f::F, mesh) where {F}
172+
return [__initial_guess(f, prob.p, t) for t in mesh]
173+
end
153174

154175
function __get_bcresid_prototype(prob::BVProblem, u)
155176
return __get_bcresid_prototype(prob.problem_type, prob, u)

0 commit comments

Comments
 (0)