@@ -331,6 +331,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
331331        analytic =  nothing ,
332332        split_idxs =  nothing ,
333333        initializeprob =  nothing ,
334+         update_initializeprob! =  nothing ,
334335        initializeprobmap =  nothing ,
335336        initializeprobpmap =  nothing ,
336337        kwargs... ) where  {iip, specialize}
@@ -434,6 +435,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
434435        sparsity =  sparsity ?  jacobian_sparsity (sys) :  nothing ,
435436        analytic =  analytic,
436437        initializeprob =  initializeprob,
438+         update_initializeprob! =  update_initializeprob!,
437439        initializeprobmap =  initializeprobmap,
438440        initializeprobpmap =  initializeprobpmap)
439441end 
@@ -778,6 +780,17 @@ function (f::GetUpdatedMTKParameters)(prob, initializesol)
778780    mtkp
779781end 
780782
783+ struct  UpdateInitializeprob{G, S}
784+     #  `getu` functor which gets all values from prob
785+     getvals:: G 
786+     #  `setu` functor which updates initializeprob with values
787+     setvals:: S 
788+ end 
789+ 
790+ function  (f:: UpdateInitializeprob )(initializeprob, prob)
791+     f. setvals (initializeprob, f. getvals (prob))
792+ end 
793+ 
781794function  get_temporary_value (p)
782795    stype =  symtype (unwrap (p))
783796    return  if  stype ==  Real
@@ -865,6 +878,10 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
865878        getpunknowns =  getu (initializeprob, punknowns)
866879        setpunknowns =  setp (sys, punknowns)
867880        initializeprobpmap =  GetUpdatedMTKParameters (getpunknowns, setpunknowns)
881+         reqd_syms =  vcat (
882+             variable_symbols (initializeprob), parameter_symbols (initializeprob))
883+         update_initializeprob! =  UpdateInitializeprob (
884+             getu (sys, reqd_syms), setu (initializeprob, reqd_syms))
868885
869886        zerovars =  Dict (setdiff (unknowns (sys), keys (defaults (sys))) .=>  0.0 )
870887        if  parammap isa  SciMLBase. NullParameters
@@ -880,6 +897,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
880897            (trueinit =  SVector {length(trueinit)} (trueinit))
881898    else 
882899        initializeprob =  nothing 
900+         update_initializeprob! =  nothing 
883901        initializeprobmap =  nothing 
884902        initializeprobpmap =  nothing 
885903        trueinit =  u0map
@@ -929,6 +947,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
929947        sparse =  sparse, eval_expression =  eval_expression,
930948        eval_module =  eval_module,
931949        initializeprob =  initializeprob,
950+         update_initializeprob! =  update_initializeprob!,
932951        initializeprobmap =  initializeprobmap,
933952        initializeprobpmap =  initializeprobpmap,
934953        kwargs... )
0 commit comments