@@ -2176,7 +2176,7 @@ For more details on this argument, see the ODEFunction documentation.
2176
2176
The fields of the ControlFunction type directly match the names of the inputs.
2177
2177
"""
2178
2178
struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
2179
- JP, CJP, SP, TPJ, O, TCV, CTCV,
2179
+ JP, CJP, SP, TW, TWt, WP, TPJ, O, TCV, CTCV,
2180
2180
SYS, ID} <: AbstractControlFunction{iip}
2181
2181
f:: F
2182
2182
mass_matrix:: TMM
@@ -2189,10 +2189,12 @@ struct ControlFunction{iip, specialize, F, TMM, Ta, Tt, TJ, CTJ, JVP, VJP,
2189
2189
jac_prototype:: JP
2190
2190
controljac_prototype:: CJP
2191
2191
sparsity:: SP
2192
+ Wfact:: TW
2193
+ Wfact_t:: TWt
2194
+ W_prototype:: WP
2192
2195
paramjac:: TPJ
2193
2196
observed:: O
2194
2197
colorvec:: TCV
2195
- controlcolorvec:: CTCV
2196
2198
sys:: SYS
2197
2199
initialization_data:: ID
2198
2200
end
@@ -4698,6 +4700,146 @@ function BatchIntegralFunction(f, integrand_prototype; kwargs...)
4698
4700
BatchIntegralFunction {calculated_iip} (f, integrand_prototype; kwargs... )
4699
4701
end
4700
4702
4703
+ function ControlFunction {iip, specialize} (f;
4704
+ mass_matrix = __has_mass_matrix (f) ? f. mass_matrix :
4705
+ I,
4706
+ analytic = __has_analytic (f) ? f. analytic : nothing ,
4707
+ tgrad = __has_tgrad (f) ? f. tgrad : nothing ,
4708
+ jac = __has_jac (f) ? f. jac : nothing ,
4709
+ controljac = __has_controljac (f) ? f. controljac : nothing ,
4710
+ jvp = __has_jvp (f) ? f. jvp : nothing ,
4711
+ vjp = __has_vjp (f) ? f. vjp : nothing ,
4712
+ jac_prototype = __has_jac_prototype (f) ?
4713
+ f. jac_prototype :
4714
+ nothing ,
4715
+ controljac_prototype = __has_controljac_prototype (f) ?
4716
+ f. controljac_prototype :
4717
+ nothing ,
4718
+ sparsity = __has_sparsity (f) ? f. sparsity :
4719
+ jac_prototype,
4720
+ Wfact = __has_Wfact (f) ? f. Wfact : nothing ,
4721
+ Wfact_t = __has_Wfact_t (f) ? f. Wfact_t : nothing ,
4722
+ W_prototype = __has_W_prototype (f) ? f. W_prototype : nothing ,
4723
+ paramjac = __has_paramjac (f) ? f. paramjac : nothing ,
4724
+ observed = __has_observed (f) ? f. observed :
4725
+ DEFAULT_OBSERVED,
4726
+ colorvec = __has_colorvec (f) ? f. colorvec : nothing ,
4727
+ sys = __has_sys (f) ? f. sys : nothing ,
4728
+ initializeprob = __has_initializeprob (f) ? f. initializeprob : nothing ,
4729
+ update_initializeprob! = __has_update_initializeprob! (f) ?
4730
+ f. update_initializeprob! : nothing ,
4731
+ initializeprobmap = __has_initializeprobmap (f) ? f. initializeprobmap : nothing ,
4732
+ initializeprobpmap = __has_initializeprobpmap (f) ? f. initializeprobpmap : nothing ,
4733
+ initialization_data = __has_initialization_data (f) ? f. initialization_data :
4734
+ nothing ,
4735
+ nlprob_data = __has_nlprob_data (f) ? f. nlprob_data : nothing
4736
+ ) where {iip,
4737
+ specialize
4738
+ }
4739
+ if mass_matrix === I && f isa Tuple
4740
+ mass_matrix = ((I for i in 1 : length (f)). .. ,)
4741
+ end
4742
+
4743
+ if (specialize === FunctionWrapperSpecialize) &&
4744
+ ! (f isa FunctionWrappersWrappers. FunctionWrappersWrapper)
4745
+ error (" FunctionWrapperSpecialize must be used on the problem constructor for access to u0, p, and t types!" )
4746
+ end
4747
+
4748
+ if jac === nothing && isa (jac_prototype, AbstractSciMLOperator)
4749
+ if iip
4750
+ jac = update_coefficients! # (J,u,p,t)
4751
+ else
4752
+ jac = (u, p, t) -> update_coefficients (deepcopy (jac_prototype), u, p, t)
4753
+ end
4754
+ end
4755
+
4756
+ if controljac === nothing && isa (controljac_prototype, AbstractSciMLOperator)
4757
+ if iip_bc
4758
+ controljac = update_coefficients! # (J,u,p,t)
4759
+ else
4760
+ controljac = (u, p, t) -> update_coefficients! (deepcopy (controljac_prototype), u, p, t)
4761
+ end
4762
+ end
4763
+
4764
+ if jac_prototype != = nothing && colorvec === nothing &&
4765
+ ArrayInterface. fast_matrix_colors (jac_prototype)
4766
+ _colorvec = ArrayInterface. matrix_colors (jac_prototype)
4767
+ else
4768
+ _colorvec = colorvec
4769
+ end
4770
+
4771
+ jaciip = jac != = nothing ? isinplace (jac, 4 , " jac" , iip) : iip
4772
+ controljaciip = controljac != = nothing ? isinplace (controljac, 4 , " controljac" , iip) : iip
4773
+ tgradiip = tgrad != = nothing ? isinplace (tgrad, 4 , " tgrad" , iip) : iip
4774
+ jvpiip = jvp != = nothing ? isinplace (jvp, 5 , " jvp" , iip) : iip
4775
+ vjpiip = vjp != = nothing ? isinplace (vjp, 5 , " vjp" , iip) : iip
4776
+ Wfactiip = Wfact != = nothing ? isinplace (Wfact, 5 , " Wfact" , iip) : iip
4777
+ Wfact_tiip = Wfact_t != = nothing ? isinplace (Wfact_t, 5 , " Wfact_t" , iip) : iip
4778
+ paramjaciip = paramjac != = nothing ? isinplace (paramjac, 4 , " paramjac" , iip) : iip
4779
+
4780
+ nonconforming = (jaciip, tgradiip, jvpiip, vjpiip, Wfactiip, Wfact_tiip,
4781
+ paramjaciip) .!= iip
4782
+ if any (nonconforming)
4783
+ nonconforming = findall (nonconforming)
4784
+ functions = [" jac" , " tgrad" , " jvp" , " vjp" , " Wfact" , " Wfact_t" , " paramjac" ][nonconforming]
4785
+ throw (NonconformingFunctionsError (functions))
4786
+ end
4787
+
4788
+ _f = prepare_function (f)
4789
+
4790
+ sys = sys_or_symbolcache (sys, syms, paramsyms, indepsym)
4791
+ initdata = reconstruct_initialization_data (
4792
+ initialization_data, initializeprob, update_initializeprob!,
4793
+ initializeprobmap, initializeprobpmap)
4794
+
4795
+ if specialize === NoSpecialize
4796
+ ControlFunction{iip, specialize,
4797
+ Any, Any, Any, Any,
4798
+ Any, Any, Any, Any, typeof (jac_prototype), typeof (controljac_prototype),
4799
+ typeof (sparsity), Any, Any, typeof (W_prototype), Any,
4800
+ Any,
4801
+ typeof (_colorvec),
4802
+ typeof (sys), Union{Nothing, OverrideInitData}}(
4803
+ _f, mass_matrix, analytic, tgrad, jac, controljac,
4804
+ jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4805
+ Wfact_t, W_prototype, paramjac,
4806
+ observed, _colorvec, sys, initdata)
4807
+ elseif specialize === false
4808
+ ControlFunction{iip, FunctionWrapperSpecialize,
4809
+ typeof (_f), typeof (mass_matrix), typeof (analytic), typeof (tgrad),
4810
+ typeof (jac), typeof (controljac), typeof (jvp), typeof (vjp), typeof (jac_prototype), typeof (controljac_prototype),
4811
+ typeof (sparsity), typeof (Wfact), typeof (Wfact_t), typeof (W_prototype),
4812
+ typeof (paramjac),
4813
+ typeof (observed),
4814
+ typeof (_colorvec),
4815
+ typeof (sys), typeof (initdata)}(_f, mass_matrix,
4816
+ analytic, tgrad, jac, controljac,
4817
+ jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4818
+ Wfact_t, W_prototype, paramjac,
4819
+ observed, _colorvec, sys, initdata)
4820
+ else
4821
+ ControlFunction{iip, specialize,
4822
+ typeof (_f), typeof (mass_matrix), typeof (analytic), typeof (tgrad),
4823
+ typeof (jac), typeof (controljac), typeof (jvp), typeof (vjp), typeof (jac_prototype), typeof (controljac_prototype),
4824
+ typeof (sparsity), typeof (Wfact), typeof (Wfact_t), typeof (W_prototype),
4825
+ typeof (paramjac),
4826
+ typeof (observed),
4827
+ typeof (_colorvec),
4828
+ typeof (sys), typeof (initdata)}(
4829
+ _f, mass_matrix, analytic, tgrad,
4830
+ jac, controljac, jvp, vjp, jac_prototype, controljac_prototype, sparsity, Wfact,
4831
+ Wfact_t, W_prototype, paramjac,
4832
+ observed, _colorvec, sys, initdata)
4833
+ end
4834
+ end
4835
+
4836
+ function ODEFunction {iip} (f; kwargs... ) where {iip}
4837
+ ODEFunction {iip, FullSpecialize} (f; kwargs... )
4838
+ end
4839
+ ODEFunction {iip} (f:: ODEFunction ; kwargs... ) where {iip} = f
4840
+ ODEFunction (f; kwargs... ) = ODEFunction {isinplace(f, 4), FullSpecialize} (f; kwargs... )
4841
+ ODEFunction (f:: ODEFunction ; kwargs... ) = f
4842
+
4701
4843
# ######### Utility functions
4702
4844
4703
4845
function sys_or_symbolcache (sys, syms, paramsyms, indepsym = nothing )
@@ -4731,6 +4873,7 @@ __has_Wfact_t(f) = isdefined(f, :Wfact_t)
4731
4873
__has_W_prototype (f) = isdefined (f, :W_prototype )
4732
4874
__has_paramjac (f) = isdefined (f, :paramjac )
4733
4875
__has_jac_prototype (f) = isdefined (f, :jac_prototype )
4876
+ __has_controljac_prototype (f) = isdefined (f, :controljac_prototype )
4734
4877
__has_sparsity (f) = isdefined (f, :sparsity )
4735
4878
__has_mass_matrix (f) = isdefined (f, :mass_matrix )
4736
4879
__has_syms (f) = isdefined (f, :syms )
0 commit comments