Skip to content

Commit 976a01e

Browse files
committed
Allow equality and inequality
1 parent 5c5d181 commit 976a01e

File tree

3 files changed

+37
-8
lines changed

3 files changed

+37
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
Manifest.toml
22
.*.swp
3+
.DS_Store
34

45
# vscode stuff
56
.vscode

src/scimlfunctions.jl

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2230,6 +2230,9 @@ with respect to time, and more. For all cases, `u0` is the initial condition,
22302230
22312231
```julia
22322232
BVPFunction{iip, specialize}(f, bc;
2233+
cost = __has_cost(f) ? f.cost : nothing,
2234+
equality = __has_equality(f) ? f.equality : nothing,
2235+
inequality = __has_inequality(f) ? f.inequality : nothing,
22332236
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
22342237
analytic = __has_analytic(f) ? f.analytic : nothing,
22352238
tgrad= __has_tgrad(f) ? f.tgrad : nothing,
@@ -2257,6 +2260,11 @@ See the section on `iip` for more details on in-place vs out-of-place handling.
22572260
All of the remaining functions are optional for improving or accelerating
22582261
the usage of `f` and `bc`. These include:
22592262
2263+
- `cost(u, p)`: the target to be minimized, similar with the `cost` function
2264+
in [`OptimizationFunction`](@ref). This is used to define the objective function
2265+
of the BVP, which can be minimized by optimization solvers.
2266+
- `equality(res, u, t)`: equality constraints functions for the BVP.
2267+
- `inequality(res, u, t)`: inequality contraints functions for the BVP.
22602268
- `mass_matrix`: the mass matrix `M` represented in the BVP function. Can be used
22612269
to determine that the equation is actually a BVP for differential algebraic equation (DAE)
22622270
if `M` is singular.
@@ -2310,12 +2318,14 @@ For more details on this argument, see the ODEFunction documentation.
23102318
23112319
The fields of the BVPFunction type directly match the names of the inputs.
23122320
"""
2313-
struct BVPFunction{iip, specialize, twopoint, F, BF, C, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
2321+
struct BVPFunction{iip, specialize, twopoint, F, BF, C, EC, IC, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
23142322
JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV,
23152323
SYS, ID} <: AbstractBVPFunction{iip, twopoint}
23162324
f::F
23172325
bc::BF
23182326
cost::C
2327+
equality::EC
2328+
inequality::IC
23192329
mass_matrix::TMM
23202330
analytic::Ta
23212331
tgrad::Tt
@@ -4327,7 +4337,9 @@ function MultiObjectiveOptimizationFunction{iip}(f, adtype::AbstractADType = NoA
43274337
end
43284338

43294339
function BVPFunction{iip, specialize, twopoint}(f, bc;
4330-
cost = (x, p) -> zero(x),
4340+
cost = __has_cost(f) ? f.cost : nothing,
4341+
equality = __has_equality(f) ? f.equality : nothing,
4342+
inequality = __has_inequality(f) ? f.inequality : nothing,
43314343
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
43324344
analytic = __has_analytic(f) ? f.analytic : nothing,
43334345
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
@@ -4428,14 +4440,17 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
44284440
Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip) : iip
44294441
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip
44304442
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip
4443+
costiip = cost !== nothing ? isinplace(cost, 2, "cost", iip) : iip
4444+
equalityiip = equality !== nothing ? isinplace(equality, 3, "equality", iip) : iip
4445+
inequalityiip = inequality !== nothing ? isinplace(inequality, 3, "inequality", iip) : iip
44314446

44324447
nonconforming = (bciip, jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
4433-
paramjaciip) .!= iip
4448+
paramjaciip, costiip, equalityiip, inequalityiip) .!= iip
44344449
bc_nonconforming = bcjaciip .!= bciip
44354450
if any(nonconforming)
44364451
nonconforming = findall(nonconforming)
4437-
functions = ["bc", "jac", "bcjac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t",
4438-
"paramjac"][nonconforming]
4452+
functions = ["bc", "jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t",
4453+
"paramjac", "cost", "equality", "inequality"][nonconforming]
44394454
throw(NonconformingFunctionsError(functions))
44404455
end
44414456

@@ -4466,24 +4481,25 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
44664481
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))
44674482

44684483
if specialize === NoSpecialize
4469-
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, Any,
4484+
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, Any, Any, Any,
44704485
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any,
44714486
Any,
44724487
Any, typeof(_colorvec), typeof(_bccolorvec), Any, Any}(
4473-
_f, bc, mass_matrix,
4488+
_f, bc, cost, equality, inequality, mass_matrix,
44744489
analytic, tgrad, jac, bcjac, jvp, vjp, jac_prototype,
44754490
bcjac_prototype, bcresid_prototype,
44764491
sparsity, Wfact, Wfact_t, paramjac, observed,
44774492
_colorvec, _bccolorvec, sys, initialization_data)
44784493
else
44794494
BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc), typeof(cost),
4495+
typeof(equality), typeof(inequality),
44804496
typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac),
44814497
typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
44824498
typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity),
44834499
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
44844500
typeof(_colorvec), typeof(_bccolorvec), typeof(sys),
44854501
typeof(initialization_data)}(
4486-
_f, bc, cost, mass_matrix, analytic,
4502+
_f, bc, cost, equality, inequality, mass_matrix, analytic,
44874503
tgrad, jac, bcjac, jvp, vjp,
44884504
jac_prototype, bcjac_prototype, bcresid_prototype, sparsity,
44894505
Wfact, Wfact_t, paramjac,
@@ -4939,6 +4955,9 @@ __has_initialization_data(f) = isdefined(f, :initialization_data)
49394955
__has_polynomialize(f) = isdefined(f, :polynomialize)
49404956
__has_unpolynomialize(f) = isdefined(f, :unpolynomialize)
49414957
__has_denominator(f) = isdefined(f, :denominator)
4958+
__has_cost(f) = isdefined(f, :cost)
4959+
__has_equality(f) = isdefined(f, :equality)
4960+
__has_inequality(f) = isdefined(f, :inequality)
49424961

49434962
# compatibility
49444963
has_invW(f::AbstractSciMLFunction) = false

test/function_building_error_messages.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,15 @@ BVPFunction(bfiip, bciip, vjp = bvjp)
650650

651651
@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfoop, bciip, vjp = bvjp)
652652

653+
BVPFunction(bfiip, bciip, cost = (x, p) -> 0.0)
654+
@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, cost = x -> 0.0)
655+
equality(u, p) = u
656+
inequality(u, p) = u
657+
@test_throws SciMLBase.NonconformingFunctionsError BVPFunction(bfiip, bciip, cost = (x, p) -> 0.0, equality_constraints = equality, inequality_constraints = inequality)
658+
equality(res, u, p) = (res .= u)
659+
inequality(res, u, p) = (res .= u)
660+
BVPFunction(bfiip, bciip, cost = (x, p) -> 0.0, equality_constraints = equality, inequality_constraints = inequality)
661+
653662
# DynamicalBVPFunction
654663

655664
dbfoop(du, u, p, t) = u

0 commit comments

Comments
 (0)