@@ -654,17 +654,40 @@ end
654654
655655function SciMLBase. late_binding_update_u0_p (
656656 prob, sys:: AbstractSystem , u0, p, t0, newu0, newp)
657+ supports_initialization (sys) || return newu0, newp
657658 u0 === missing && return newu0, (p === missing ? copy (newp) : newp)
658- eltype (u0) <: Pair || return newu0, (p === missing ? copy (newp) : newp)
659+ # non-symbolic u0 updates initials...
660+ if ! (eltype (u0) <: Pair )
661+ # if `p` is not provided or is symbolic
662+ p === missing || eltype (p) <: Pair || return newu0, newp
663+ newu0 === nothing && return newu0, newp
664+ all (is_parameter (sys, Initial (x)) for x in unknowns (sys)) || return newu0, newp
665+ newp = p === missing ? copy (newp) : newp
666+ initials, repack, alias = SciMLStructures. canonicalize (
667+ SciMLStructures. Initials (), newp)
668+ if eltype (initials) != eltype (newu0)
669+ initials = DiffEqBase. promote_u0 (initials, newu0, t0)
670+ newp = repack (initials)
671+ end
672+ if length (newu0) != length (unknowns (sys))
673+ throw (ArgumentError (" Expected `newu0` to be of same length as unknowns ($(length (unknowns (sys))) ). Got $(typeof (newu0)) of length $(length (newu0)) " ))
674+ end
675+ setp (sys, Initial .(unknowns (sys)))(newp, newu0)
676+ return newu0, newp
677+ end
659678
660679 newp = p === missing ? copy (newp) : newp
661680 newu0 = DiffEqBase. promote_u0 (newu0, newp, t0)
662681 tunables, repack, alias = SciMLStructures. canonicalize (SciMLStructures. Tunable (), newp)
663- tunables = DiffEqBase. promote_u0 (tunables, newu0, t0)
664- newp = repack (tunables)
682+ if eltype (tunables) != eltype (newu0)
683+ tunables = DiffEqBase. promote_u0 (tunables, newu0, t0)
684+ newp = repack (tunables)
685+ end
665686 initials, repack, alias = SciMLStructures. canonicalize (SciMLStructures. Initials (), newp)
666- initials = DiffEqBase. promote_u0 (initials, newu0, t0)
667- newp = repack (initials)
687+ if eltype (initials) != eltype (newu0)
688+ initials = DiffEqBase. promote_u0 (initials, newu0, t0)
689+ newp = repack (initials)
690+ end
668691
669692 allsyms = all_symbols (sys)
670693 for (k, v) in u0
0 commit comments