@@ -280,6 +280,25 @@ function isautonomous(sys::AbstractODESystem)
280280 all (iszero, tgrad)
281281end
282282
283+ struct GetAndSetFunctor{G, S}
284+ getter:: G
285+ setter:: S
286+ end
287+
288+ function (gs:: GetAndSetFunctor )(dest, source)
289+ gs. setter (dest, gs. getter (source))
290+ end
291+
292+ function generate_initializeprob_init (sys:: AbstractSystem , initsys:: AbstractSystem )
293+ syms = vcat (variable_symbols (initsys), parameter_symbols (initsys))
294+ return GetAndSetFunctor (getu (sys, syms), setu (initsys, syms))
295+ end
296+
297+ function generate_initializeprob_update (sys:: AbstractSystem , initsys:: AbstractSystem )
298+ syms = vcat (variable_symbols (sys), parameter_symbols (sys))
299+ return GetAndSetFunctor (getu (initsys, syms), setu (sys, syms))
300+ end
301+
283302"""
284303```julia
285304DiffEqBase.ODEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys),
@@ -323,8 +342,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
323342 analytic = nothing ,
324343 split_idxs = nothing ,
325344 initializeprob = nothing ,
326- initializeprobmap = nothing ,
327- initializeprob_updatep ! = nothing ,
345+ initializeprob_init! = nothing ,
346+ initializeprob_update ! = nothing ,
328347 kwargs... ) where {iip, specialize}
329348 if ! iscomplete (sys)
330349 error (" A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`" )
@@ -507,8 +526,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
507526 sparsity = sparsity ? jacobian_sparsity (sys) : nothing ,
508527 analytic = analytic,
509528 initializeprob = initializeprob,
510- initializeprobmap = initializeprobmap ,
511- initializeprob_updatep ! = initializeprob_updatep !)
529+ initializeprob_init! = initializeprob_init! ,
530+ initializeprob_update ! = initializeprob_update !)
512531end
513532
514533"""
@@ -539,8 +558,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
539558 eval_module = @__MODULE__ ,
540559 checkbounds = false ,
541560 initializeprob = nothing ,
542- initializeprobmap = nothing ,
543- initializeprob_updatep ! = nothing ,
561+ initializeprob_init! = nothing ,
562+ initializeprob_update ! = nothing ,
544563 kwargs... ) where {iip}
545564 if ! iscomplete (sys)
546565 error (" A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`" )
@@ -614,8 +633,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
614633 jac_prototype = jac_prototype,
615634 observed = observedfun,
616635 initializeprob = initializeprob,
617- initializeprobmap = initializeprobmap ,
618- initializeprob_updatep ! = initializeprob_updatep !)
636+ initializeprob_init! = initializeprob_init! ,
637+ initializeprob_update ! = initializeprob_update !)
619638end
620639
621640function DiffEqBase. DDEFunction (sys:: AbstractODESystem , args... ; kwargs... )
@@ -927,26 +946,11 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
927946 end
928947 initializeprob = ModelingToolkit. InitializationProblem (
929948 sys, t, u0map, parammap; guesses, warn_initialize_determined)
930- unks = unknowns (sys)
931- initializeprobmap = isempty (unks) ? (_... ) -> nothing :
932- getu (initializeprob, unknowns (sys))
933- if any (p -> is_variable (initializeprob, p) || is_observed (initializeprob, p),
934- parameters (sys))
935- punknowns = [p
936- for p in parameters (sys)
937- if is_variable (initializeprob, p) ||
938- is_observed (initializeprob, p)]
939- initializeprob_updatep! = let getter = getu (initializeprob, tovar .(punknowns)),
940- setter = setp (sys, punknowns)
941-
942- function (ps, initsol)
943- setter (ps, getter (initsol))
944- end
945- end
946- else
947- punknowns = []
948- initializeprob_updatep! = nothing
949- end
949+ punknowns = [p
950+ for p in parameters (sys)
951+ if is_variable (initializeprob, p) || is_observed (initializeprob, p)]
952+ initializeprob_init! = generate_initializeprob_init (sys, initializeprob. f. sys)
953+ initializeprob_update! = generate_initializeprob_update (sys, initializeprob. f. sys)
950954 zerovars = Dict (setdiff (unknowns (sys), keys (defaults (sys))) .=> 0.0 )
951955 zeropars = Dict ()
952956 for p in punknowns
@@ -961,9 +965,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
961965 (trueinit = SVector {length(trueinit)} (trueinit))
962966 else
963967 initializeprob = nothing
964- initializeprobmap = nothing
965- initializeprob_updatep! = nothing
966968 zeropars = Dict ()
969+ initializeprob_init! = nothing
970+ initializeprob_update! = nothing
967971 trueinit = u0map
968972 end
969973
@@ -1012,9 +1016,8 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
10121016 checkbounds = checkbounds, p = p,
10131017 linenumbers = linenumbers, parallel = parallel, simplify = simplify,
10141018 sparse = sparse, eval_expression = eval_expression,
1015- initializeprob = initializeprob,
1016- initializeprobmap = initializeprobmap,
1017- initializeprob_updatep! = initializeprob_updatep!,
1019+ initializeprob = initializeprob, initializeprob_init! = initializeprob_init!,
1020+ initializeprob_update! = initializeprob_update!,
10181021 kwargs... )
10191022 implicit_dae ? (f, du0, u0, p) : (f, u0, p)
10201023end
0 commit comments