Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/problems/ode_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,17 @@ function Base.setproperty!(prob::ODEProblem, s::Symbol, v, order::Symbol)
Base.setfield!(prob, s, v, order)
end

function ConstructionBase.constructorof(::Type{P}) where {P <: ODEProblem}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For SciML/DiffEqBase.jl#1070, but it might not be now.

function ctor(f, u0, tspan, p, kw, pt)
if f isa AbstractODEFunction
iip = isinplace(f)
else
iip = isinplace(f, 4)
end
return ODEProblem{iip}(f, u0, tspan, p, pt; kw...)
end
end

"""
ODEProblem(f::ODEFunction,u0,tspan,p=NullParameters(),callback=CallbackSet())

Expand Down
84 changes: 54 additions & 30 deletions src/remake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,20 +112,24 @@ function remake(prob::ODEProblem; f = missing,
p = missing,
kwargs = missing,
interpret_symbolicmap = true,
build_initializeprob = true,
use_defaults = false,
_kwargs...)
if tspan === missing
tspan = prob.tspan
end

newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this breaking downstream?


iip = isinplace(prob)

if f === missing
initializeprob, initializeprobmap = remake_initializeprob(
prob.f.sys, prob.f, u0 === missing ? newu0 : u0,
tspan[1], p === missing ? newp : p)
if build_initializeprob
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap = remake_initializeprob(
prob.f.sys, prob.f, u0, tspan[1], p)
else
initializeprob = update_initializeprob! = initializeprobmap = initializeprobpmap = nothing
end
if specialization(prob.f) === FunctionWrapperSpecialize
ptspan = promote_tspan(tspan)
if iip
Expand All @@ -134,14 +138,14 @@ function remake(prob::ODEProblem; f = missing,
unwrapped_f(prob.f.f),
(newu0, newu0, newp,
ptspan[1]));
initializeprob, initializeprobmap)
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
else
_f = ODEFunction{iip, FunctionWrapperSpecialize}(
wrapfun_oop(
unwrapped_f(prob.f.f),
(newu0, newp,
ptspan[1]));
initializeprob, initializeprobmap)
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap)
end
else
_f = prob.f
Expand All @@ -152,13 +156,27 @@ function remake(prob::ODEProblem; f = missing,
_f = parameterless_type(_f){
iip, specialization(_f), map(typeof, props)...}(props...)
end
if __has_update_initializeprob!(_f)
props = getproperties(_f)
@reset props.update_initializeprob! = update_initializeprob!
props = values(props)
_f = parameterless_type(_f){
iip, specialization(_f), map(typeof, props)...}(props...)
end
if __has_initializeprobmap(_f)
props = getproperties(_f)
@reset props.initializeprobmap = initializeprobmap
props = values(props)
_f = parameterless_type(_f){
iip, specialization(_f), map(typeof, props)...}(props...)
end
if __has_initializeprobpmap(_f)
props = getproperties(_f)
@reset props.initializeprobpmap = initializeprobpmap
props = values(props)
_f = parameterless_type(_f){
iip, specialization(_f), map(typeof, props)...}(props...)
end
end
elseif f isa AbstractODEFunction
_f = f
Expand Down Expand Up @@ -189,15 +207,20 @@ end
remake_initializeprob(sys, scimlfn, u0, t0, p)

Re-create the initialization problem present in the function `scimlfn`, using the
associated system `sys`, and the new values of `u0`, initial time `t0` and `p`. By
default, returns `nothing, nothing` if `scimlfn` does not have an initialization
problem, and `scimlfn.initializeprob, scimlfn.initializeprobmap` if it does.
associated system `sys`, and the user-provided new values of `u0`, initial time `t0` and
`p`. By default, returns `nothing, nothing, nothing, nothing` if `scimlfn` does not have an
initialization problem, and
`scimlfn.initializeprob, scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap`
if it does.

Note that `u0` or `p` may be `missing` if the user does not provide a value for them.
"""
function remake_initializeprob(sys, scimlfn, u0, t0, p)
if !has_initializeprob(scimlfn)
return nothing, nothing
return nothing, nothing, nothing, nothing
end
return scimlfn.initializeprob, scimlfn.initializeprobmap
return scimlfn.initializeprob,
scimlfn.update_initializeprob!, scimlfn.initializeprobmap, scimlfn.initializeprobpmap
end

"""
Expand All @@ -214,7 +237,7 @@ function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = miss
tspan = prob.tspan
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)

if problem_type === missing
problem_type = prob.problem_type
Expand Down Expand Up @@ -280,7 +303,7 @@ function remake(prob::SDEProblem;
tspan = prob.tspan
end

u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)

if noise === missing
noise = prob.noise
Expand Down Expand Up @@ -496,35 +519,35 @@ anydict(d) = Dict{Any, Any}(d)
anydict() = Dict{Any, Any}()

function _updated_u0_p_internal(
prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false)
prob, ::Missing, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
return state_values(prob), parameter_values(prob)
end
function _updated_u0_p_internal(
prob, ::Missing, p; interpret_symbolicmap = true, use_defaults = false)
prob, ::Missing, p, t0; interpret_symbolicmap = true, use_defaults = false)
u0 = state_values(prob)

if p isa AbstractArray && isempty(p)
return _updated_u0_p_internal(
prob, u0, parameter_values(prob); interpret_symbolicmap)
prob, u0, parameter_values(prob), t0; interpret_symbolicmap)
end
eltype(p) <: Pair && interpret_symbolicmap || return u0, p
defs = default_values(prob)
p = fill_p(prob, anydict(p); defs, use_defaults)
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
end

function _updated_u0_p_internal(
prob, u0, ::Missing; interpret_symbolicmap = true, use_defaults = false)
prob, u0, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
p = parameter_values(prob)

eltype(u0) <: Pair || return u0, p
defs = default_values(prob)
u0 = fill_u0(prob, anydict(u0); defs, use_defaults)
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
end

function _updated_u0_p_internal(
prob, u0, p; interpret_symbolicmap = true, use_defaults = false)
prob, u0, p, t0; interpret_symbolicmap = true, use_defaults = false)
isu0symbolic = eltype(u0) <: Pair
ispsymbolic = eltype(p) <: Pair && interpret_symbolicmap

Expand All @@ -538,7 +561,7 @@ function _updated_u0_p_internal(
if ispsymbolic
p = fill_p(prob, anydict(p); defs, use_defaults)
end
return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic))
return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic), t0)
end

function fill_u0(prob, u0; defs = nothing, use_defaults = false)
Expand Down Expand Up @@ -629,7 +652,7 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
return newvals
end

function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p

Expand All @@ -642,13 +665,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
# FIXME: need to provide `u` since the observed function expects it.
# This is sort of an implicit dependency on MTK. The values of `u` won't actually be
# used, since any state symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = state_values(prob), p = p)
temp_state = ProblemState(; u = state_values(prob), p = p, t = t0)
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in u0)
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
end

function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))

Expand All @@ -661,13 +684,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
# FIXME: need to provide `p` since the observed function expects an `MTKParameters`
# this is sort of an implicit dependency on MTK. The values of `p` won't actually be
# used, since any parameter symbols in the expression were substituted out earlier.
temp_state = ProblemState(; u = u0, p = parameter_values(prob))
temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0)
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
for (k, v) in p)
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end

function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
isu0dep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
ispdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)

Expand All @@ -677,11 +700,11 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
end
if !isu0dep
u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0))
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
end
if !ispdep
p = remake_buffer(prob, parameter_values(prob), keys(p), values(p))
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
end

varmap = merge(u0, p)
Expand All @@ -693,7 +716,8 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
end

function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults = false)
function updated_u0_p(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nope it's here

prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
if u0 === missing && p === missing
return state_values(prob), parameter_values(prob)
end
Expand All @@ -712,7 +736,7 @@ function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults =
return (u0 === missing ? state_values(prob) : u0),
(p === missing ? parameter_values(prob) : p)
end
return _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap, use_defaults)
return _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults)
end

# overloaded in MTK to intercept symbolic remake
Expand Down
Loading
Loading