@@ -331,6 +331,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
331331 analytic = nothing ,
332332 split_idxs = nothing ,
333333 initializeprob = nothing ,
334+ update_initializeprob! = nothing ,
334335 initializeprobmap = nothing ,
335336 initializeprobpmap = nothing ,
336337 kwargs... ) where {iip, specialize}
@@ -434,6 +435,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
434435 sparsity = sparsity ? jacobian_sparsity (sys) : nothing ,
435436 analytic = analytic,
436437 initializeprob = initializeprob,
438+ update_initializeprob! = update_initializeprob!,
437439 initializeprobmap = initializeprobmap,
438440 initializeprobpmap = initializeprobpmap)
439441end
@@ -778,6 +780,17 @@ function (f::GetUpdatedMTKParameters)(prob, initializesol)
778780 mtkp
779781end
780782
783+ struct UpdateInitializeprob{G, S}
784+ # `getu` functor which gets all values from prob
785+ getvals:: G
786+ # `setu` functor which updates initializeprob with values
787+ setvals:: S
788+ end
789+
790+ function (f:: UpdateInitializeprob )(initializeprob, prob)
791+ f. setvals (initializeprob, f. getvals (prob))
792+ end
793+
781794function get_temporary_value (p)
782795 stype = symtype (unwrap (p))
783796 return if stype == Real
@@ -866,6 +879,10 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
866879 getpunknowns = getu (initializeprob, punknowns)
867880 setpunknowns = setp (sys, punknowns)
868881 initializeprobpmap = GetUpdatedMTKParameters (getpunknowns, setpunknowns)
882+ reqd_syms = vcat (
883+ variable_symbols (initializeprob), parameter_symbols (initializeprob))
884+ update_initializeprob! = UpdateInitializeprob (
885+ getu (sys, reqd_syms), setu (initializeprob, reqd_syms))
869886
870887 zerovars = Dict (setdiff (unknowns (sys), keys (defaults (sys))) .=> 0.0 )
871888 if parammap isa SciMLBase. NullParameters
@@ -881,6 +898,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
881898 (trueinit = SVector {length(trueinit)} (trueinit))
882899 else
883900 initializeprob = nothing
901+ update_initializeprob! = nothing
884902 initializeprobmap = nothing
885903 initializeprobpmap = nothing
886904 trueinit = u0map
@@ -930,6 +948,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
930948 sparse = sparse, eval_expression = eval_expression,
931949 eval_module = eval_module,
932950 initializeprob = initializeprob,
951+ update_initializeprob! = update_initializeprob!,
933952 initializeprobmap = initializeprobmap,
934953 initializeprobpmap = initializeprobpmap,
935954 kwargs... )
0 commit comments