Skip to content

Commit 8f3b3b9

Browse files
committed
Allow Optimization interface in BVPFunction and BVProblem
1 parent a613095 commit 8f3b3b9

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

.DS_Store

6 KB
Binary file not shown.

src/problems/bvp_problems.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,19 @@ every solve call.
113113
doesn't store array size as part of type information. If we can't reliably infer this,
114114
we set it to `Nothing`. Downstreams solvers must be setup to deal with this case.
115115
"""
116-
struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
116+
struct BVProblem{uType, tType, isinplace, nlls, P, F, LC, UC, PT, K} <:
117117
AbstractBVProblem{uType, tType, isinplace, nlls}
118118
f::F
119119
u0::uType
120120
tspan::tType
121121
p::P
122+
lcons::LC
123+
ucons::UC
122124
problem_type::PT
123125
kwargs::K
124126

125127
@add_kwonly function BVProblem{iip}(f::AbstractBVPFunction{iip, TP}, u0, tspan,
126-
p = NullParameters(); problem_type = nothing, nlls = nothing,
128+
p = NullParameters(); lcons = nothing, ucons = nothing, problem_type = nothing, nlls = nothing,
127129
kwargs...) where {iip, TP}
128130
_u0 = prepare_initial_state(u0)
129131
_tspan = promote_tspan(tspan)
@@ -172,8 +174,8 @@ struct BVProblem{uType, tType, isinplace, nlls, P, F, PT, K} <:
172174
_nlls = _unwrap_val(nlls)
173175
end
174176

175-
return new{typeof(_u0), typeof(_tspan), iip, _nlls, typeof(p), typeof(f),
176-
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, problem_type, kwargs)
177+
return new{typeof(_u0), typeof(_tspan), iip, _nlls, typeof(p), typeof(f), typeof(lcons), typeof(ucons),
178+
typeof(problem_type), typeof(kwargs)}(f, _u0, _tspan, p, lcons, ucons, problem_type, kwargs)
177179
end
178180

179181
function BVProblem{iip}(f, bc, u0, tspan, p = NullParameters(); kwargs...) where {iip}

src/scimlfunctions.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2310,11 +2310,12 @@ For more details on this argument, see the ODEFunction documentation.
23102310
23112311
The fields of the BVPFunction type directly match the names of the inputs.
23122312
"""
2313-
struct BVPFunction{iip, specialize, twopoint, F, BF, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
2313+
struct BVPFunction{iip, specialize, twopoint, F, BF, C, TMM, Ta, Tt, TJ, BCTJ, JVP, VJP,
23142314
JP, BCJP, BCRP, SP, TW, TWt, TPJ, O, TCV, BCTCV,
23152315
SYS, ID} <: AbstractBVPFunction{iip, twopoint}
23162316
f::F
23172317
bc::BF
2318+
cost::C
23182319
mass_matrix::TMM
23192320
analytic::Ta
23202321
tgrad::Tt
@@ -4326,6 +4327,7 @@ function MultiObjectiveOptimizationFunction{iip}(f, adtype::AbstractADType = NoA
43264327
end
43274328

43284329
function BVPFunction{iip, specialize, twopoint}(f, bc;
4330+
cost = (x, p) -> zero(x),
43294331
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix : I,
43304332
analytic = __has_analytic(f) ? f.analytic : nothing,
43314333
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
@@ -4464,7 +4466,7 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
44644466
sys = something(sys, SymbolCache(syms, paramsyms, indepsym))
44654467

44664468
if specialize === NoSpecialize
4467-
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any,
4469+
BVPFunction{iip, specialize, twopoint, Any, Any, Any, Any, Any, Any,
44684470
Any, Any, Any, Any, Any, Any, Any, Any, Any, Any,
44694471
Any,
44704472
Any, typeof(_colorvec), typeof(_bccolorvec), Any, Any}(
@@ -4474,14 +4476,14 @@ function BVPFunction{iip, specialize, twopoint}(f, bc;
44744476
sparsity, Wfact, Wfact_t, paramjac, observed,
44754477
_colorvec, _bccolorvec, sys, initialization_data)
44764478
else
4477-
BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc),
4479+
BVPFunction{iip, specialize, twopoint, typeof(_f), typeof(bc), typeof(cost),
44784480
typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac),
44794481
typeof(bcjac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
44804482
typeof(bcjac_prototype), typeof(bcresid_prototype), typeof(sparsity),
44814483
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
44824484
typeof(_colorvec), typeof(_bccolorvec), typeof(sys),
44834485
typeof(initialization_data)}(
4484-
_f, bc, mass_matrix, analytic,
4486+
_f, bc, cost, mass_matrix, analytic,
44854487
tgrad, jac, bcjac, jvp, vjp,
44864488
jac_prototype, bcjac_prototype, bcresid_prototype, sparsity,
44874489
Wfact, Wfact_t, paramjac,

0 commit comments

Comments
 (0)