Skip to content

Commit 9b63c96

Browse files
committed
add ControlFunction constructor
1 parent 24d4096 commit 9b63c96

File tree

2 files changed

+146
-3
lines changed

2 files changed

+146
-3
lines changed

src/problems/implicit_discrete_problems.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ dt: the time step
2727
2828
### Constructors
2929
30-
- `ImplicitDiscreteProblem(f::ODEFunction,u0,tspan,p=NullParameters();kwargs...)` :
30+
- `ImplicitDiscreteProblem(f::ImplicitDiscreteFunction,u0,tspan,p=NullParameters();kwargs...)` :
3131
Defines the discrete problem with the specified functions.
3232
- `ImplicitDiscreteProblem{isinplace,specialize}(f,u0,tspan,p=NullParameters();kwargs...)` :
3333
Defines the discrete problem with the specified functions.

src/scimlfunctions.jl

Lines changed: 145 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2176,7 +2176,7 @@ For more details on this argument, see the ODEFunction documentation.
21762176
The fields of the ControlFunction type directly match the names of the inputs.
21772177
"""
21782178
struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
2179-
JP, CJP, SP, TPJ, O, TCV, CTCV,
2179+
JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV, CTCV,
21802180
SYS, ID} <: AbstractControlFunction{iip}
21812181
f::F
21822182
mass_matrix::TMM
@@ -2189,10 +2189,12 @@ struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
21892189
jac_prototype::JP
21902190
controljac_prototype::CJP
21912191
sparsity::SP
2192+
Wfact::TW
2193+
Wfact_t::TWt
2194+
W_prototype::WP
21922195
paramjac::TPJ
21932196
observed::O
21942197
colorvec::TCV
2195-
controlcolorvec::CTCV
21962198
sys::SYS
21972199
initialization_data::ID
21982200
end
@@ -4698,6 +4700,146 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...)
46984700
BatchIntegralFunction{calculated_iip}(f, integrand_prototype; kwargs...)
46994701
end
47004702

4703+
function ControlFunction{iip, specialize}(f;
4704+
mass_matrix = __has_mass_matrix(f) ? f.mass_matrix :
4705+
I,
4706+
analytic = __has_analytic(f) ? f.analytic : nothing,
4707+
tgrad = __has_tgrad(f) ? f.tgrad : nothing,
4708+
jac = __has_jac(f) ? f.jac : nothing,
4709+
controljac = __has_controljac(f) ? f.controljac : nothing,
4710+
jvp = __has_jvp(f) ? f.jvp : nothing,
4711+
vjp = __has_vjp(f) ? f.vjp : nothing,
4712+
jac_prototype = __has_jac_prototype(f) ?
4713+
f.jac_prototype :
4714+
nothing,
4715+
controljac_prototype = __has_controljac_prototype(f) ?
4716+
f.controljac_prototype :
4717+
nothing,
4718+
sparsity = __has_sparsity(f) ? f.sparsity :
4719+
jac_prototype,
4720+
Wfact = __has_Wfact(f) ? f.Wfact : nothing,
4721+
Wfact_t = __has_Wfact_t(f) ? f.Wfact_t : nothing,
4722+
W_prototype = __has_W_prototype(f) ? f.W_prototype : nothing,
4723+
paramjac = __has_paramjac(f) ? f.paramjac : nothing,
4724+
observed = __has_observed(f) ? f.observed :
4725+
DEFAULT_OBSERVED,
4726+
colorvec = __has_colorvec(f) ? f.colorvec : nothing,
4727+
sys = __has_sys(f) ? f.sys : nothing,
4728+
initializeprob = __has_initializeprob(f) ? f.initializeprob : nothing,
4729+
update_initializeprob! = __has_update_initializeprob!(f) ?
4730+
f.update_initializeprob! : nothing,
4731+
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
4732+
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
4733+
initialization_data = __has_initialization_data(f) ? f.initialization_data :
4734+
nothing,
4735+
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
4736+
) where {iip,
4737+
specialize
4738+
}
4739+
if mass_matrix === I && f isa Tuple
4740+
mass_matrix = ((I for i in 1:length(f))...,)
4741+
end
4742+
4743+
if (specialize === FunctionWrapperSpecialize) &&
4744+
!(f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
4745+
error("FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!")
4746+
end
4747+
4748+
if jac === nothing && isa(jac_prototype, AbstractSciMLOperator)
4749+
if iip
4750+
jac = update_coefficients! #(J,u,p,t)
4751+
else
4752+
jac = (u, p, t) -> update_coefficients(deepcopy(jac_prototype), u, p, t)
4753+
end
4754+
end
4755+
4756+
if controljac === nothing && isa(controljac_prototype, AbstractSciMLOperator)
4757+
if iip_bc
4758+
controljac = update_coefficients! #(J,u,p,t)
4759+
else
4760+
controljac = (u, p, t) -> update_coefficients!(deepcopy(controljac_prototype), u, p, t)
4761+
end
4762+
end
4763+
4764+
if jac_prototype !== nothing && colorvec === nothing &&
4765+
ArrayInterface.fast_matrix_colors(jac_prototype)
4766+
_colorvec = ArrayInterface.matrix_colors(jac_prototype)
4767+
else
4768+
_colorvec = colorvec
4769+
end
4770+
4771+
jaciip = jac !== nothing ? isinplace(jac, 4, "jac", iip) : iip
4772+
controljaciip = controljac !== nothing ? isinplace(controljac, 4, "controljac", iip) : iip
4773+
tgradiip = tgrad !== nothing ? isinplace(tgrad, 4, "tgrad", iip) : iip
4774+
jvpiip = jvp !== nothing ? isinplace(jvp, 5, "jvp", iip) : iip
4775+
vjpiip = vjp !== nothing ? isinplace(vjp, 5, "vjp", iip) : iip
4776+
Wfactiip = Wfact !== nothing ? isinplace(Wfact, 5, "Wfact", iip) : iip
4777+
Wfact_tiip = Wfact_t !== nothing ? isinplace(Wfact_t, 5, "Wfact_t", iip) : iip
4778+
paramjaciip = paramjac !== nothing ? isinplace(paramjac, 4, "paramjac", iip) : iip
4779+
4780+
nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
4781+
paramjaciip) .!= iip
4782+
if any(nonconforming)
4783+
nonconforming = findall(nonconforming)
4784+
functions = ["jac", "tgrad", "jvp", "vjp", "Wfact", "Wfact_t", "paramjac"][nonconforming]
4785+
throw(NonconformingFunctionsError(functions))
4786+
end
4787+
4788+
_f = prepare_function(f)
4789+
4790+
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
4791+
initdata = reconstruct_initialization_data(
4792+
initialization_data, initializeprob, update_initializeprob!,
4793+
initializeprobmap, initializeprobpmap)
4794+
4795+
if specialize === NoSpecialize
4796+
ControlFunction{iip, specialize,
4797+
Any, Any, Any, Any,
4798+
Any, Any, Any, Any, typeof(jac_prototype), typeof(controljac_prototype),
4799+
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
4800+
Any,
4801+
typeof(_colorvec),
4802+
typeof(sys), Union{Nothing, OverrideInitData}}(
4803+
_f, mass_matrix, analytic, tgrad, jac, controljac,
4804+
jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4805+
Wfact_t, W_prototype, paramjac,
4806+
observed, _colorvec, sys, initdata)
4807+
elseif specialize === false
4808+
ControlFunction{iip, FunctionWrapperSpecialize,
4809+
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
4810+
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
4811+
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
4812+
typeof(paramjac),
4813+
typeof(observed),
4814+
typeof(_colorvec),
4815+
typeof(sys), typeof(initdata)}(_f, mass_matrix,
4816+
analytic, tgrad, jac, controljac,
4817+
jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4818+
Wfact_t, W_prototype, paramjac,
4819+
observed, _colorvec, sys, initdata)
4820+
else
4821+
ControlFunction{iip, specialize,
4822+
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
4823+
typeof(jac), typeof(controljac), typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(controljac_prototype),
4824+
typeof(sparsity), typeof(Wfact), typeof(Wfact_t), typeof(W_prototype),
4825+
typeof(paramjac),
4826+
typeof(observed),
4827+
typeof(_colorvec),
4828+
typeof(sys), typeof(initdata)}(
4829+
_f, mass_matrix, analytic, tgrad,
4830+
jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4831+
Wfact_t, W_prototype, paramjac,
4832+
observed, _colorvec, sys, initdata)
4833+
end
4834+
end
4835+
4836+
function ODEFunction{iip}(f; kwargs...) where {iip}
4837+
ODEFunction{iip, FullSpecialize}(f; kwargs...)
4838+
end
4839+
ODEFunction{iip}(f::ODEFunction; kwargs...) where {iip} = f
4840+
ODEFunction(f; kwargs...) = ODEFunction{isinplace(f, 4), FullSpecialize}(f; kwargs...)
4841+
ODEFunction(f::ODEFunction; kwargs...) = f
4842+
47014843
########## Utility functions
47024844

47034845
function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
@@ -4731,6 +4873,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t)
47314873
__has_W_prototype(f) = isdefined(f, :W_prototype)
47324874
__has_paramjac(f) = isdefined(f, :paramjac)
47334875
__has_jac_prototype(f) = isdefined(f, :jac_prototype)
4876+
__has_controljac_prototype(f) = isdefined(f, :controljac_prototype)
47344877
__has_sparsity(f) = isdefined(f, :sparsity)
47354878
__has_mass_matrix(f) = isdefined(f, :mass_matrix)
47364879
__has_syms(f) = isdefined(f, :syms)

0 commit comments

Comments
 (0)