Skip to content

Commit 991e559

Browse files
fix: use tspan in updated_u0_p for initial values dependent on time
1 parent 02d4a08 commit 991e559

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

src/remake.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ function remake(prob::ODEProblem; f = missing,
118118
tspan = prob.tspan
119119
end
120120

121-
newu0, newp = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
121+
newu0, newp = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
122122

123123
iip = isinplace(prob)
124124

@@ -214,7 +214,7 @@ function remake(prob::BVProblem{uType, tType, iip, nlls}; f = missing, bc = miss
214214
tspan = prob.tspan
215215
end
216216

217-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
217+
u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
218218

219219
if problem_type === missing
220220
problem_type = prob.problem_type
@@ -280,7 +280,7 @@ function remake(prob::SDEProblem;
280280
tspan = prob.tspan
281281
end
282282

283-
u0, p = updated_u0_p(prob, u0, p; interpret_symbolicmap, use_defaults)
283+
u0, p = updated_u0_p(prob, u0, p, tspan[1]; interpret_symbolicmap, use_defaults)
284284

285285
if noise === missing
286286
noise = prob.noise
@@ -496,35 +496,35 @@ anydict(d) = Dict{Any, Any}(d)
496496
anydict() = Dict{Any, Any}()
497497

498498
function _updated_u0_p_internal(
499-
prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false)
499+
prob, ::Missing, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
500500
return state_values(prob), parameter_values(prob)
501501
end
502502
function _updated_u0_p_internal(
503-
prob, ::Missing, p; interpret_symbolicmap = true, use_defaults = false)
503+
prob, ::Missing, p, t0; interpret_symbolicmap = true, use_defaults = false)
504504
u0 = state_values(prob)
505505

506506
if p isa AbstractArray && isempty(p)
507507
return _updated_u0_p_internal(
508-
prob, u0, parameter_values(prob); interpret_symbolicmap)
508+
prob, u0, parameter_values(prob), t0; interpret_symbolicmap)
509509
end
510510
eltype(p) <: Pair && interpret_symbolicmap || return u0, p
511511
defs = default_values(prob)
512512
p = fill_p(prob, anydict(p); defs, use_defaults)
513-
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
513+
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
514514
end
515515

516516
function _updated_u0_p_internal(
517-
prob, u0, ::Missing; interpret_symbolicmap = true, use_defaults = false)
517+
prob, u0, ::Missing, t0; interpret_symbolicmap = true, use_defaults = false)
518518
p = parameter_values(prob)
519519

520520
eltype(u0) <: Pair || return u0, p
521521
defs = default_values(prob)
522522
u0 = fill_u0(prob, anydict(u0); defs, use_defaults)
523-
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
523+
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
524524
end
525525

526526
function _updated_u0_p_internal(
527-
prob, u0, p; interpret_symbolicmap = true, use_defaults = false)
527+
prob, u0, p, t0; interpret_symbolicmap = true, use_defaults = false)
528528
isu0symbolic = eltype(u0) <: Pair
529529
ispsymbolic = eltype(p) <: Pair && interpret_symbolicmap
530530

@@ -538,7 +538,7 @@ function _updated_u0_p_internal(
538538
if ispsymbolic
539539
p = fill_p(prob, anydict(p); defs, use_defaults)
540540
end
541-
return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic))
541+
return _updated_u0_p_symmap(prob, u0, Val(isu0symbolic), p, Val(ispsymbolic), t0)
542542
end
543543

544544
function fill_u0(prob, u0; defs = nothing, use_defaults = false)
@@ -629,7 +629,7 @@ function fill_p(prob, p; defs = nothing, use_defaults = false)
629629
return newvals
630630
end
631631

632-
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
632+
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false}, t0)
633633
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
634634
isdep || return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
635635

@@ -642,13 +642,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{false})
642642
# FIXME: need to provide `u` since the observed function expects it.
643643
# This is sort of an implicit dependency on MTK. The values of `u` won't actually be
644644
# used, since any state symbols in the expression were substituted out earlier.
645-
temp_state = ProblemState(; u = state_values(prob), p = p)
645+
temp_state = ProblemState(; u = state_values(prob), p = p, t = t0)
646646
u0 = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
647647
for (k, v) in u0)
648648
return remake_buffer(prob, state_values(prob), keys(u0), values(u0)), p
649649
end
650650

651-
function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
651+
function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true}, t0)
652652
isdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
653653
isdep || return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
654654

@@ -661,13 +661,13 @@ function _updated_u0_p_symmap(prob, u0, ::Val{false}, p, ::Val{true})
661661
# FIXME: need to provide `p` since the observed function expects an `MTKParameters`
662662
# this is sort of an implicit dependency on MTK. The values of `p` won't actually be
663663
# used, since any parameter symbols in the expression were substituted out earlier.
664-
temp_state = ProblemState(; u = u0, p = parameter_values(prob))
664+
temp_state = ProblemState(; u = u0, p = parameter_values(prob), t = t0)
665665
p = anydict(k => symbolic_type(v) === NotSymbolic() ? v : getu(prob, v)(temp_state)
666666
for (k, v) in p)
667667
return u0, remake_buffer(prob, parameter_values(prob), keys(p), values(p))
668668
end
669669

670-
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
670+
function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true}, t0)
671671
isu0dep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in u0)
672672
ispdep = any(symbolic_type(v) !== NotSymbolic() for (_, v) in p)
673673

@@ -677,11 +677,11 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
677677
end
678678
if !isu0dep
679679
u0 = remake_buffer(prob, state_values(prob), keys(u0), values(u0))
680-
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true))
680+
return _updated_u0_p_symmap(prob, u0, Val(false), p, Val(true), t0)
681681
end
682682
if !ispdep
683683
p = remake_buffer(prob, parameter_values(prob), keys(p), values(p))
684-
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false))
684+
return _updated_u0_p_symmap(prob, u0, Val(true), p, Val(false), t0)
685685
end
686686

687687
varmap = merge(u0, p)
@@ -693,7 +693,7 @@ function _updated_u0_p_symmap(prob, u0, ::Val{true}, p, ::Val{true})
693693
remake_buffer(prob, parameter_values(prob), keys(p), values(p))
694694
end
695695

696-
function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults = false)
696+
function updated_u0_p(prob, u0, p, t0 = nothing; interpret_symbolicmap = true, use_defaults = false)
697697
if u0 === missing && p === missing
698698
return state_values(prob), parameter_values(prob)
699699
end
@@ -712,7 +712,7 @@ function updated_u0_p(prob, u0, p; interpret_symbolicmap = true, use_defaults =
712712
return (u0 === missing ? state_values(prob) : u0),
713713
(p === missing ? parameter_values(prob) : p)
714714
end
715-
return _updated_u0_p_internal(prob, u0, p; interpret_symbolicmap, use_defaults)
715+
return _updated_u0_p_internal(prob, u0, p, t0; interpret_symbolicmap, use_defaults)
716716
end
717717

718718
# overloaded in MTK to intercept symbolic remake

0 commit comments

Comments
 (0)