@@ -112,20 +112,24 @@ function remake(prob::ODEProblem; f = missing,
112112 p = missing ,
113113 kwargs = missing ,
114114 interpret_symbolicmap = true ,
115+ build_initializeprob = true ,
115116 use_defaults = false ,
116117 _kwargs... )
117118 if tspan === missing
118119 tspan = prob. tspan
119120 end
120121
121- newu0, newp = updated_u0_p (prob, u0, p; interpret_symbolicmap, use_defaults)
122+ newu0, newp = updated_u0_p (prob, u0, p, tspan[ 1 ] ; interpret_symbolicmap, use_defaults)
122123
123124 iip = isinplace (prob)
124125
125126 if f === missing
126- initializeprob, initializeprobmap = remake_initializeprob (
127- prob. f. sys, prob. f, u0 === missing ? newu0 : u0,
128- tspan[1 ], p === missing ? newp : p)
127+ if build_initializeprob
128+ initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob (
129+ prob. f. sys, prob. f, u0, tspan[1 ], p)
130+ else
131+ initializeprob = update_initializeprob! = initializeprobmap = initializeprobpmap = nothing
132+ end
129133 if specialization (prob. f) === FunctionWrapperSpecialize
130134 ptspan = promote_tspan (tspan)
131135 if iip
@@ -134,14 +138,14 @@ function remake(prob::ODEProblem; f = missing,
134138 unwrapped_f (prob. f. f),
135139 (newu0, newu0, newp,
136140 ptspan[1 ]));
137- initializeprob, initializeprobmap)
141+ initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap )
138142 else
139143 _f = ODEFunction {iip, FunctionWrapperSpecialize} (
140144 wrapfun_oop (
141145 unwrapped_f (prob. f. f),
142146 (newu0, newp,
143147 ptspan[1 ]));
144- initializeprob, initializeprobmap)
148+ initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap )
145149 end
146150 else
147151 _f = prob. f
@@ -152,13 +156,27 @@ function remake(prob::ODEProblem; f = missing,
152156 _f = parameterless_type (_f){
153157 iip, specialization (_f), map (typeof, props)... }(props... )
154158 end
159+ if __has_update_initializeprob! (_f)
160+ props = getproperties (_f)
161+ @reset props. update_initializeprob! = update_initializeprob!
162+ props = values (props)
163+ _f = parameterless_type (_f){
164+ iip, specialization (_f), map (typeof, props)... }(props... )
165+ end
155166 if __has_initializeprobmap (_f)
156167 props = getproperties (_f)
157168 @reset props. initializeprobmap = initializeprobmap
158169 props = values (props)
159170 _f = parameterless_type (_f){
160171 iip, specialization (_f), map (typeof, props)... }(props... )
161172 end
173+ if __has_initializeprobpmap (_f)
174+ props = getproperties (_f)
175+ @reset props. initializeprobpmap = initializeprobpmap
176+ props = values (props)
177+ _f = parameterless_type (_f){
178+ iip, specialization (_f), map (typeof, props)... }(props... )
179+ end
162180 end
163181 elseif f isa AbstractODEFunction
164182 _f = f
@@ -189,15 +207,20 @@ end
189207 remake_initializeprob(sys, scimlfn, u0, t0, p)
190208
191209Re-create the initialization problem present in the function `scimlfn`, using the
192- associated system `sys`, and the new values of `u0`, initial time `t0` and `p`. By
193- default, returns `nothing, nothing` if `scimlfn` does not have an initialization
194- problem, and `scimlfn.initializeprob, scimlfn.initializeprobmap` if it does.
210+ associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and
211+ `p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an
212+ initialization problem, and
213+ `scimlfn.initializeprob, scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap`
214+ if it does.
215+
216+ Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
195217"""
196218function remake_initializeprob (sys, scimlfn, u0, t0, p)
197219 if ! has_initializeprob (scimlfn)
198- return nothing , nothing
220+ return nothing , nothing , nothing , nothing
199221 end
200- return scimlfn. initializeprob, scimlfn. initializeprobmap
222+ return scimlfn. initializeprob,
223+ scimlfn. update_initializeprob!, scimlfn. initializeprobmap, scimlfn. initializeprobpmap
201224end
202225
203226"""
@@ -214,7 +237,7 @@ function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = miss
214237 tspan = prob. tspan
215238 end
216239
217- u0, p = updated_u0_p (prob, u0, p; interpret_symbolicmap, use_defaults)
240+ u0, p = updated_u0_p (prob, u0, p, tspan[ 1 ] ; interpret_symbolicmap, use_defaults)
218241
219242 if problem_type === missing
220243 problem_type = prob. problem_type
@@ -280,7 +303,7 @@ function remake(prob::SDEProblem;
280303 tspan = prob. tspan
281304 end
282305
283- u0, p = updated_u0_p (prob, u0, p; interpret_symbolicmap, use_defaults)
306+ u0, p = updated_u0_p (prob, u0, p, tspan[ 1 ] ; interpret_symbolicmap, use_defaults)
284307
285308 if noise === missing
286309 noise = prob. noise
@@ -496,35 +519,35 @@ anydict(d) = Dict{Any, Any}(d)
496519anydict () = Dict {Any, Any} ()
497520
498521function _updated_u0_p_internal (
499- prob, :: Missing , :: Missing ; interpret_symbolicmap = true , use_defaults = false )
522+ prob, :: Missing , :: Missing , t0 ; interpret_symbolicmap = true , use_defaults = false )
500523 return state_values (prob), parameter_values (prob)
501524end
502525function _updated_u0_p_internal (
503- prob, :: Missing , p; interpret_symbolicmap = true , use_defaults = false )
526+ prob, :: Missing , p, t0 ; interpret_symbolicmap = true , use_defaults = false )
504527 u0 = state_values (prob)
505528
506529 if p isa AbstractArray && isempty (p)
507530 return _updated_u0_p_internal (
508- prob, u0, parameter_values (prob); interpret_symbolicmap)
531+ prob, u0, parameter_values (prob), t0 ; interpret_symbolicmap)
509532 end
510533 eltype (p) <: Pair && interpret_symbolicmap || return u0, p
511534 defs = default_values (prob)
512535 p = fill_p (prob, anydict (p); defs, use_defaults)
513- return _updated_u0_p_symmap (prob, u0, Val (false ), p, Val (true ))
536+ return _updated_u0_p_symmap (prob, u0, Val (false ), p, Val (true ), t0 )
514537end
515538
516539function _updated_u0_p_internal (
517- prob, u0, :: Missing ; interpret_symbolicmap = true , use_defaults = false )
540+ prob, u0, :: Missing , t0 ; interpret_symbolicmap = true , use_defaults = false )
518541 p = parameter_values (prob)
519542
520543 eltype (u0) <: Pair || return u0, p
521544 defs = default_values (prob)
522545 u0 = fill_u0 (prob, anydict (u0); defs, use_defaults)
523- return _updated_u0_p_symmap (prob, u0, Val (true ), p, Val (false ))
546+ return _updated_u0_p_symmap (prob, u0, Val (true ), p, Val (false ), t0 )
524547end
525548
526549function _updated_u0_p_internal (
527- prob, u0, p; interpret_symbolicmap = true , use_defaults = false )
550+ prob, u0, p, t0 ; interpret_symbolicmap = true , use_defaults = false )
528551 isu0symbolic = eltype (u0) <: Pair
529552 ispsymbolic = eltype (p) <: Pair && interpret_symbolicmap
530553
@@ -538,7 +561,7 @@ function _updated_u0_p_internal(
538561 if ispsymbolic
539562 p = fill_p (prob, anydict (p); defs, use_defaults)
540563 end
541- return _updated_u0_p_symmap (prob, u0, Val (isu0symbolic), p, Val (ispsymbolic))
564+ return _updated_u0_p_symmap (prob, u0, Val (isu0symbolic), p, Val (ispsymbolic), t0 )
542565end
543566
544567function fill_u0 (prob, u0; defs = nothing , use_defaults = false )
@@ -629,7 +652,7 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
629652 return newvals
630653end
631654
632- function _updated_u0_p_symmap (prob, u0, :: Val{true} , p, :: Val{false} )
655+ function _updated_u0_p_symmap (prob, u0, :: Val{true} , p, :: Val{false} , t0 )
633656 isdep = any (symbolic_type (v) != = NotSymbolic () for (_, v) in u0)
634657 isdep || return remake_buffer (prob, state_values (prob), keys (u0), values (u0)), p
635658
@@ -642,13 +665,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
642665 # FIXME : need to provide `u` since the observed function expects it.
643666 # This is sort of an implicit dependency on MTK. The values of `u` won't actually be
644667 # used, since any state symbols in the expression were substituted out earlier.
645- temp_state = ProblemState (; u = state_values (prob), p = p)
668+ temp_state = ProblemState (; u = state_values (prob), p = p, t = t0 )
646669 u0 = anydict (k => symbolic_type (v) === NotSymbolic () ? v : getu (prob, v)(temp_state)
647670 for (k, v) in u0)
648671 return remake_buffer (prob, state_values (prob), keys (u0), values (u0)), p
649672end
650673
651- function _updated_u0_p_symmap (prob, u0, :: Val{false} , p, :: Val{true} )
674+ function _updated_u0_p_symmap (prob, u0, :: Val{false} , p, :: Val{true} , t0 )
652675 isdep = any (symbolic_type (v) != = NotSymbolic () for (_, v) in p)
653676 isdep || return u0, remake_buffer (prob, parameter_values (prob), keys (p), values (p))
654677
@@ -661,13 +684,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
661684 # FIXME : need to provide `p` since the observed function expects an `MTKParameters`
662685 # this is sort of an implicit dependency on MTK. The values of `p` won't actually be
663686 # used, since any parameter symbols in the expression were substituted out earlier.
664- temp_state = ProblemState (; u = u0, p = parameter_values (prob))
687+ temp_state = ProblemState (; u = u0, p = parameter_values (prob), t = t0 )
665688 p = anydict (k => symbolic_type (v) === NotSymbolic () ? v : getu (prob, v)(temp_state)
666689 for (k, v) in p)
667690 return u0, remake_buffer (prob, parameter_values (prob), keys (p), values (p))
668691end
669692
670- function _updated_u0_p_symmap (prob, u0, :: Val{true} , p, :: Val{true} )
693+ function _updated_u0_p_symmap (prob, u0, :: Val{true} , p, :: Val{true} , t0 )
671694 isu0dep = any (symbolic_type (v) != = NotSymbolic () for (_, v) in u0)
672695 ispdep = any (symbolic_type (v) != = NotSymbolic () for (_, v) in p)
673696
@@ -677,11 +700,11 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
677700 end
678701 if ! isu0dep
679702 u0 = remake_buffer (prob, state_values (prob), keys (u0), values (u0))
680- return _updated_u0_p_symmap (prob, u0, Val (false ), p, Val (true ))
703+ return _updated_u0_p_symmap (prob, u0, Val (false ), p, Val (true ), t0 )
681704 end
682705 if ! ispdep
683706 p = remake_buffer (prob, parameter_values (prob), keys (p), values (p))
684- return _updated_u0_p_symmap (prob, u0, Val (true ), p, Val (false ))
707+ return _updated_u0_p_symmap (prob, u0, Val (true ), p, Val (false ), t0 )
685708 end
686709
687710 varmap = merge (u0, p)
@@ -693,7 +716,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
693716 remake_buffer (prob, parameter_values (prob), keys (p), values (p))
694717end
695718
696- function updated_u0_p (prob, u0, p; interpret_symbolicmap = true , use_defaults = false )
719+ function updated_u0_p (
720+ prob, u0, p, t0 = nothing ; interpret_symbolicmap = true , use_defaults = false )
697721 if u0 === missing && p === missing
698722 return state_values (prob), parameter_values (prob)
699723 end
@@ -712,7 +736,7 @@ function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults =
712736 return (u0 === missing ? state_values (prob) : u0),
713737 (p === missing ? parameter_values (prob) : p)
714738 end
715- return _updated_u0_p_internal (prob, u0, p; interpret_symbolicmap, use_defaults)
739+ return _updated_u0_p_internal (prob, u0, p, t0 ; interpret_symbolicmap, use_defaults)
716740end
717741
718742# overloaded in MTK to intercept symbolic remake
0 commit comments