Skip to content

Commit 16815db

Browse files
Merge pull request #760 from AayushSabharwal/as/param-init
feat: support `initializeprobpmap` in relevant `SciMLFunctions`
2 parents 70ae7fb + a25ee6a commit 16815db

File tree

3 files changed

+138
-62
lines changed

3 files changed

+138
-62
lines changed

src/problems/ode_problems.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,17 @@ function Base.setproperty!(prob::ODEProblem, s::Symbol, v, order::Symbol)
174174
Base.setfield!(prob, s, v, order)
175175
end
176176

177+
function ConstructionBase.constructorof(::Type{P}) where {P <: ODEProblem}
178+
function ctor(f, u0, tspan, p, kw, pt)
179+
if f isa AbstractODEFunction
180+
iip = isinplace(f)
181+
else
182+
iip = isinplace(f, 4)
183+
end
184+
return ODEProblem{iip}(f, u0, tspan, p, pt; kw...)
185+
end
186+
end
187+
177188
"""
178189
ODEProblem(f::ODEFunction,u0,tspan,p=NullParameters(),callback=CallbackSet())
179190

src/remake.jl

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,24 @@ function remake(prob::ODEProblem; f = missing,
112112
p = missing,
113113
kwargs = missing,
114114
interpret_symbolicmap = true,
115+
build_initializeprob = true,
115116
use_defaults = false,
116117
_kwargs...)
117118
if tspan === missing
118119
tspan = prob.tspan
119120
end
120121

121-
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
122+
newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
122123

123124
iip = isinplace(prob)
124125

125126
if f === missing
126-
initializeprob, initializeprobmap = remake_initializeprob(
127-
prob.f.sys, prob.f, u0 === missing ? newu0 : u0,
128-
tspan[1], p === missing ? newp : p)
127+
if build_initializeprob
128+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob(
129+
prob.f.sys, prob.f, u0, tspan[1], p)
130+
else
131+
initializeprob = update_initializeprob! = initializeprobmap = initializeprobpmap = nothing
132+
end
129133
if specialization(prob.f) === FunctionWrapperSpecialize
130134
ptspan = promote_tspan(tspan)
131135
if iip
@@ -134,14 +138,14 @@ function remake(prob::ODEProblem; f = missing,
134138
unwrapped_f(prob.f.f),
135139
(newu0, newu0, newp,
136140
ptspan[1]));
137-
initializeprob, initializeprobmap)
141+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
138142
else
139143
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
140144
wrapfun_oop(
141145
unwrapped_f(prob.f.f),
142146
(newu0, newp,
143147
ptspan[1]));
144-
initializeprob, initializeprobmap)
148+
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
145149
end
146150
else
147151
_f = prob.f
@@ -152,13 +156,27 @@ function remake(prob::ODEProblem; f = missing,
152156
_f = parameterless_type(_f){
153157
iip, specialization(_f), map(typeof, props)...}(props...)
154158
end
159+
if __has_update_initializeprob!(_f)
160+
props = getproperties(_f)
161+
@reset props.update_initializeprob! = update_initializeprob!
162+
props = values(props)
163+
_f = parameterless_type(_f){
164+
iip, specialization(_f), map(typeof, props)...}(props...)
165+
end
155166
if __has_initializeprobmap(_f)
156167
props = getproperties(_f)
157168
@reset props.initializeprobmap = initializeprobmap
158169
props = values(props)
159170
_f = parameterless_type(_f){
160171
iip, specialization(_f), map(typeof, props)...}(props...)
161172
end
173+
if __has_initializeprobpmap(_f)
174+
props = getproperties(_f)
175+
@reset props.initializeprobpmap = initializeprobpmap
176+
props = values(props)
177+
_f = parameterless_type(_f){
178+
iip, specialization(_f), map(typeof, props)...}(props...)
179+
end
162180
end
163181
elseif f isa AbstractODEFunction
164182
_f = f
@@ -189,15 +207,20 @@ end
189207
remake_initializeprob(sys, scimlfn, u0, t0, p)
190208
191209
Re-create the initialization problem present in the function `scimlfn`, using the
192-
associated system `sys`, and the new values of `u0`, initial time `t0` and `p`. By
193-
default, returns `nothing, nothing` if `scimlfn` does not have an initialization
194-
problem, and `scimlfn.initializeprob, scimlfn.initializeprobmap` if it does.
210+
associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and
211+
`p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an
212+
initialization problem, and
213+
`scimlfn.initializeprob, scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap`
214+
if it does.
215+
216+
Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
195217
"""
196218
function remake_initializeprob(sys, scimlfn, u0, t0, p)
197219
if !has_initializeprob(scimlfn)
198-
return nothing, nothing
220+
return nothing, nothing, nothing, nothing
199221
end
200-
return scimlfn.initializeprob, scimlfn.initializeprobmap
222+
return scimlfn.initializeprob,
223+
scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap
201224
end
202225

203226
"""
@@ -214,7 +237,7 @@ function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = miss
214237
tspan = prob.tspan
215238
end
216239

217-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
240+
u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
218241

219242
if problem_type === missing
220243
problem_type = prob.problem_type
@@ -280,7 +303,7 @@ function remake(prob::SDEProblem;
280303
tspan = prob.tspan
281304
end
282305

283-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
306+
u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
284307

285308
if noise === missing
286309
noise = prob.noise
@@ -496,35 +519,35 @@ anydict(d) = Dict{Any, Any}(d)
496519
anydict() = Dict{Any, Any}()
497520

498521
function _updated_u0_p_internal(
499-
prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false)
522+
prob, ::Missing, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
500523
return state_values(prob), parameter_values(prob)
501524
end
502525
function _updated_u0_p_internal(
503-
prob, ::Missing, p; interpret_symbolicmap = true, use_defaults = false)
526+
prob, ::Missing, p, t0; interpret_symbolicmap = true, use_defaults = false)
504527
u0 = state_values(prob)
505528

506529
if p isa AbstractArray && isempty(p)
507530
return _updated_u0_p_internal(
508-
prob, u0, parameter_values(prob); interpret_symbolicmap)
531+
prob, u0, parameter_values(prob), t0; interpret_symbolicmap)
509532
end
510533
eltype(p) <: Pair && interpret_symbolicmap || return u0, p
511534
defs = default_values(prob)
512535
p = fill_p(prob, anydict(p); defs, use_defaults)
513-
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
536+
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
514537
end
515538

516539
function _updated_u0_p_internal(
517-
prob, u0, ::Missing; interpret_symbolicmap = true, use_defaults = false)
540+
prob, u0, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
518541
p = parameter_values(prob)
519542

520543
eltype(u0) <: Pair || return u0, p
521544
defs = default_values(prob)
522545
u0 = fill_u0(prob, anydict(u0); defs, use_defaults)
523-
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
546+
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
524547
end
525548

526549
function _updated_u0_p_internal(
527-
prob, u0, p; interpret_symbolicmap = true, use_defaults = false)
550+
prob, u0, p, t0; interpret_symbolicmap = true, use_defaults = false)
528551
isu0symbolic = eltype(u0) <: Pair
529552
ispsymbolic = eltype(p) <: Pair && interpret_symbolicmap
530553

@@ -538,7 +561,7 @@ function _updated_u0_p_internal(
538561
if ispsymbolic
539562
p = fill_p(prob, anydict(p); defs, use_defaults)
540563
end
541-
return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic))
564+
return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic), t0)
542565
end
543566

544567
function fill_u0(prob, u0; defs = nothing, use_defaults = false)
@@ -629,7 +652,7 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
629652
return newvals
630653
end
631654

632-
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
655+
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
633656
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
634657
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
635658

@@ -642,13 +665,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
642665
# FIXME: need to provide `u` since the observed function expects it.
643666
# This is sort of an implicit dependency on MTK. The values of `u` won't actually be
644667
# used, since any state symbols in the expression were substituted out earlier.
645-
temp_state = ProblemState(; u = state_values(prob), p = p)
668+
temp_state = ProblemState(; u = state_values(prob), p = p, t = t0)
646669
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
647670
for (k, v) in u0)
648671
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
649672
end
650673

651-
function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
674+
function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
652675
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
653676
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
654677

@@ -661,13 +684,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
661684
# FIXME: need to provide `p` since the observed function expects an `MTKParameters`
662685
# this is sort of an implicit dependency on MTK. The values of `p` won't actually be
663686
# used, since any parameter symbols in the expression were substituted out earlier.
664-
temp_state = ProblemState(; u = u0, p = parameter_values(prob))
687+
temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0)
665688
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
666689
for (k, v) in p)
667690
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
668691
end
669692

670-
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
693+
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
671694
isu0dep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
672695
ispdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
673696

@@ -677,11 +700,11 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
677700
end
678701
if !isu0dep
679702
u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0))
680-
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
703+
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
681704
end
682705
if !ispdep
683706
p = remake_buffer(prob, parameter_values(prob), keys(p), values(p))
684-
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
707+
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
685708
end
686709

687710
varmap = merge(u0, p)
@@ -693,7 +716,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
693716
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
694717
end
695718

696-
function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults = false)
719+
function updated_u0_p(
720+
prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
697721
if u0 === missing && p === missing
698722
return state_values(prob), parameter_values(prob)
699723
end
@@ -712,7 +736,7 @@ function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults =
712736
return (u0 === missing ? state_values(prob) : u0),
713737
(p === missing ? parameter_values(prob) : p)
714738
end
715-
return _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap, use_defaults)
739+
return _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults)
716740
end
717741

718742
# overloaded in MTK to intercept symbolic remake

0 commit comments

Comments
 (0)