Skip to content

Commit 3fef7cc

Browse files
Merge pull request #1065 from ErikQQY/qqy/bvp_opt
Allow Optimization interface in BVPFunction and BVProblem
2 parents a9735d6 + c038fb2 commit 3fef7cc

File tree

4 files changed

+45
-12
lines changed

4 files changed

+45
-12
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
Manifest.toml
22
.*.swp
3+
.DS_Store
34

45
# vscode stuff
56
.vscode

src/problems/bvp_problems.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,19 @@ every solve call.
113113
doesn't store array size as part of type information. If we can't reliably infer this,
114114
we set it to `Nothing`. Downstreams solvers must be setup to deal with this case.
115115
"""
116-
struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
116+
struct BVProblem{uType, tType, isinplace, nlls, P, F, LC, UC, PT, K} <:
117117
AbstractBVProblem{uType, tType, isinplace, nlls}
118118
f::F
119119
u0::uType
120120
tspan::tType
121121
p::P
122+
lcons::LC
123+
ucons::UC
122124
problem_type::PT
123125
kwargs::K
124126

125127
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
126-
p = NullParameters(); problem_type = nothing, nlls = nothing,
128+
p = NullParameters(); lcons = nothing, ucons = nothing, problem_type = nothing, nlls = nothing,
127129
kwargs...) where {iip, TP}
128130
_u0 = prepare_initial_state(u0)
129131
_tspan = promote_tspan(tspan)
@@ -172,8 +174,8 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
172174
_nlls = _unwrap_val(nlls)
173175
end
174176

175-
return new{typeof(_u0), typeof(_tspan), iip, _nlls, typeof(p), typeof(f),
176-
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs)
177+
return new{typeof(_u0), typeof(_tspan), iip, _nlls, typeof(p), typeof(f), typeof(lcons), typeof(ucons),
178+
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, lcons, ucons, problem_type, kwargs)
177179
end
178180

179181
function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}

src/scimlfunctions.jl

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2230,6 +2230,9 @@ with respect to time, and more. For all cases, `u0` is the initial condition,
22302230
22312231
```julia
22322232
BVPFunction{iip, specialize}(f, bc;
2233+
cost = __has_cost(f) ? f.cost : nothing,
2234+
equality = __has_equality(f) ? f.equality : nothing,
2235+
inequality = __has_inequality(f) ? f.inequality : nothing,
22332236
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
22342237
analytic = __has_analytic(f) ? f.analytic : nothing,
22352238
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
@@ -2257,6 +2260,11 @@ See the section on `iip` for more details on in-place vs out-of-place handling.
22572260
All of the remaining functions are optional for improving or accelerating
22582261
the usage of `f` and `bc`. These include:
22592262
2263+
- `cost(u, p)`: the target to be minimized, similar with the `cost` function
2264+
in [`OptimizationFunction`](@ref). This is used to define the objective function
2265+
of the BVP, which can be minimized by optimization solvers.
2266+
- `equality(res, u, t)`: equality constraints functions for the BVP.
2267+
- `inequality(res, u, t)`: inequality contraints functions for the BVP.
22602268
- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used
22612269
to determine that the equation is actually a BVP for differential algebraic equation (DAE)
22622270
if `M` is singular.
@@ -2310,11 +2318,14 @@ For more details on this argument, see the ODEFunction documentation.
23102318
23112319
The fields of the BVPFunction type directly match the names of the inputs.
23122320
"""
2313-
struct BVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
2321+
struct BVPFunction{iip, specialize, twopoint, F, BF, C, EC, IC, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
23142322
JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV,
23152323
SYS, ID} <: AbstractBVPFunction{iip, twopoint}
23162324
f::F
23172325
bc::BF
2326+
cost::C
2327+
equality::EC
2328+
inequality::IC
23182329
mass_matrix::TMM
23192330
analytic::Ta
23202331
tgrad::Tt
@@ -4326,6 +4337,9 @@ function MultiObjectiveOptimizationFunction{iip}(f, adtype::AbstractADType = NoA
43264337
end
43274338

43284339
function BVPFunction{iip, specialize, twopoint}(f, bc;
4340+
cost = __has_cost(f) ? f.cost : nothing,
4341+
equality = __has_equality(f) ? f.equality : nothing,
4342+
inequality = __has_inequality(f) ? f.inequality : nothing,
43294343
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
43304344
analytic = __has_analytic(f) ? f.analytic : nothing,
43314345
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
@@ -4426,14 +4440,17 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
44264440
Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip) : iip
44274441
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip
44284442
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip
4443+
costiip = cost !== nothing ? isinplace(cost, 2, "cost", iip) : iip
4444+
equalityiip = equality !== nothing ? isinplace(equality, 3, "equality", iip) : iip
4445+
inequalityiip = inequality !== nothing ? isinplace(inequality, 3, "inequality", iip) : iip
44294446

44304447
nonconforming = (bciip, jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
4431-
paramjaciip) .!= iip
4448+
paramjaciip, costiip, equalityiip, inequalityiip) .!= iip
44324449
bc_nonconforming = bcjaciip .!= bciip
44334450
if any(nonconforming)
44344451
nonconforming = findall(nonconforming)
4435-
functions = ["bc", "jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t",
4436-
"paramjac"][nonconforming]
4452+
functions = ["bc", "jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t",
4453+
"paramjac", "cost", "equality", "inequality"][nonconforming]
44374454
throw(NonconformingFunctionsError(functions))
44384455
end
44394456

@@ -4464,24 +4481,25 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
44644481
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))
44654482

44664483
if specialize === NoSpecialize
4467-
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any,
4484+
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, Any, Any, Any,
44684485
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any,
44694486
Any,
44704487
Any, typeof(_colorvec), typeof(_bccolorvec), Any, Any}(
4471-
_f, bc, mass_matrix,
4488+
_f, bc, cost, equality, inequality, mass_matrix,
44724489
analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype,
44734490
bcjac_prototype, bcresid_prototype,
44744491
sparsity, Wfact, Wfact_t, paramjac, observed,
44754492
_colorvec, _bccolorvec, sys, initialization_data)
44764493
else
4477-
BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc),
4494+
BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc), typeof(cost),
4495+
typeof(equality), typeof(inequality),
44784496
typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac),
44794497
typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
44804498
typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity),
44814499
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
44824500
typeof(_colorvec), typeof(_bccolorvec), typeof(sys),
44834501
typeof(initialization_data)}(
4484-
_f, bc, mass_matrix, analytic,
4502+
_f, bc, cost, equality, inequality, mass_matrix, analytic,
44854503
tgrad, jac, bcjac, jvp, vjp,
44864504
jac_prototype, bcjac_prototype, bcresid_prototype, sparsity,
44874505
Wfact, Wfact_t, paramjac,
@@ -4937,6 +4955,9 @@ __has_initialization_data(f) = isdefined(f, :initialization_data)
49374955
__has_polynomialize(f) = isdefined(f, :polynomialize)
49384956
__has_unpolynomialize(f) = isdefined(f, :unpolynomialize)
49394957
__has_denominator(f) = isdefined(f, :denominator)
4958+
__has_cost(f) = isdefined(f, :cost)
4959+
__has_equality(f) = isdefined(f, :equality)
4960+
__has_inequality(f) = isdefined(f, :inequality)
49404961

49414962
# compatibility
49424963
has_invW(f::AbstractSciMLFunction) = false

test/function_building_error_messages.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,15 @@ BVPFunction(bfiip, bciip, vjp = bvjp)
650650

651651
@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, vjp = bvjp)
652652

653+
BVPFunction(bfiip, bciip, cost = (x, p) -> 0.0)
654+
@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, cost = x -> 0.0)
655+
equality(u, p) = u
656+
inequality(u, p) = u
657+
@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, cost = (x, p) -> 0.0, equality = equality, inequality = inequality)
658+
equality(res, u, p) = (res .= u)
659+
inequality(res, u, p) = (res .= u)
660+
BVPFunction(bfiip, bciip, cost = (x, p) -> 0.0, equality = equality, inequality = inequality)
661+
653662
# DynamicalBVPFunction
654663

655664
dbfoop(du, u, p, t) = u

0 commit comments

Comments
 (0)