@@ -324,6 +324,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
324324 split_idxs = nothing ,
325325 initializeprob = nothing ,
326326 initializeprobmap = nothing ,
327+ initializeprob_updatep! = nothing ,
327328 kwargs... ) where {iip, specialize}
328329 if ! iscomplete (sys)
329330 error (" A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`" )
@@ -506,7 +507,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
506507 sparsity = sparsity ? jacobian_sparsity (sys) : nothing ,
507508 analytic = analytic,
508509 initializeprob = initializeprob,
509- initializeprobmap = initializeprobmap)
510+ initializeprobmap = initializeprobmap,
511+ initializeprob_updatep! = initializeprob_updatep!)
510512end
511513
512514"""
@@ -538,6 +540,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
538540 checkbounds = false ,
539541 initializeprob = nothing ,
540542 initializeprobmap = nothing ,
543+ initializeprob_updatep! = nothing ,
541544 kwargs... ) where {iip}
542545 if ! iscomplete (sys)
543546 error (" A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`" )
@@ -611,7 +614,8 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
611614 jac_prototype = jac_prototype,
612615 observed = observedfun,
613616 initializeprob = initializeprob,
614- initializeprobmap = initializeprobmap)
617+ initializeprobmap = initializeprobmap,
618+ initializeprob_updatep! = initializeprob_updatep!)
615619end
616620
617621function DiffEqBase. DDEFunction (sys:: AbstractODESystem , args... ; kwargs... )
@@ -862,7 +866,6 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
862866 varmap = canonicalize_varmap (varmap)
863867 varlist = collect (map (unwrap, dvs))
864868 missingvars = setdiff (varlist, collect (keys (varmap)))
865-
866869 # Append zeros to the variables which are determined by the initialization system
867870 # This essentially bypasses the check for if initial conditions are defined for DAEs
868871 # since they will be checked in the initialization problem's construction
@@ -873,11 +876,14 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
873876 parammap = Dict (unwrap (k) => v for (k, v) in todict (parammap))
874877 elseif parammap isa AbstractArray
875878 if isempty (parammap)
876- parammap = SciMLBase . NullParameters ()
879+ parammap = Dict ()
877880 else
878881 parammap = Dict (unwrap .(parameters (sys)) .=> parammap)
879882 end
883+ elseif parammap === nothing || parammap isa SciMLBase. NullParameters
884+ parammap = Dict ()
880885 end
886+ missingpars = setdiff (parameters (sys), keys (parammap))
881887
882888 if has_discrete_subsystems (sys) && get_discrete_subsystems (sys) != = nothing
883889 clockedparammap = Dict ()
@@ -886,7 +892,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
886892 v = unwrap (v)
887893 is_discrete_domain (v) || continue
888894 op = operation (v)
889- if ! isa (op, Symbolics. Operator) && parammap != SciMLBase . NullParameters ( ) &&
895+ if ! isa (op, Symbolics. Operator) && ! isempty (parammap ) &&
890896 haskey (parammap, v)
891897 error (" Initial conditions for discrete variables must be for the past state of the unknown. Instead of providing the condition for $v , provide the condition for $(Shift (iv, - 1 )(v)) ." )
892898 end
@@ -909,7 +915,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
909915 # TODO : make it work with clocks
910916 # ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
911917 if sys isa ODESystem && build_initializeprob &&
912- (implicit_dae || ! isempty (missingvars)) &&
918+ (implicit_dae || ! isempty (missingvars) || ! isempty (missingpars) ) &&
913919 all (isequal (Continuous ()), ci. var_domain) &&
914920 ModelingToolkit. get_tearing_state (sys) != = nothing &&
915921 t != = nothing
@@ -921,15 +927,43 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
921927 end
922928 initializeprob = ModelingToolkit. InitializationProblem (
923929 sys, t, u0map, parammap; guesses, warn_initialize_determined)
924- initializeprobmap = getu (initializeprob, unknowns (sys))
925-
930+ unks = unknowns (sys)
931+ initializeprobmap = isempty (unks) ? (_... ) -> nothing :
932+ getu (initializeprob, unknowns (sys))
933+ if any (p -> is_variable (initializeprob, p) || is_observed (initializeprob, p),
934+ parameters (sys))
935+ punknowns = [p
936+ for p in parameters (sys)
937+ if is_variable (initializeprob, p) ||
938+ is_observed (initializeprob, p)]
939+ initializeprob_updatep! = let getter = getu (initializeprob, tovar .(punknowns)),
940+ setter = setp (sys, punknowns)
941+
942+ function (ps, initsol)
943+ setter (ps, getter (initsol))
944+ end
945+ end
946+ else
947+ punknowns = []
948+ initializeprob_updatep! = nothing
949+ end
926950 zerovars = Dict (setdiff (unknowns (sys), keys (defaults (sys))) .=> 0.0 )
951+ zeropars = Dict ()
952+ for p in punknowns
953+ zeropars[p] = if Symbolics. isarraysymbolic (p)
954+ collect (unwrap .(zero (p)))
955+ else
956+ unwrap (zero (p))
957+ end
958+ end
927959 trueinit = collect (merge (zerovars, eltype (u0map) <: Pair ? todict (u0map) : u0map))
928960 u0map isa StaticArraysCore. StaticArray &&
929961 (trueinit = SVector {length(trueinit)} (trueinit))
930962 else
931963 initializeprob = nothing
932964 initializeprobmap = nothing
965+ initializeprob_updatep! = nothing
966+ zeropars = Dict ()
933967 trueinit = u0map
934968 end
935969
@@ -940,7 +974,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
940974 parammap == SciMLBase. NullParameters () && isempty (defs)
941975 nothing
942976 else
943- MTKParameters (sys, parammap, trueinit)
977+ if parammap === nothing || parammap == SciMLBase. NullParameters ()
978+ parammap = Dict ()
979+ else
980+ parammap = todict (parammap)
981+ end
982+ MTKParameters (sys, merge (parammap, zeropars), trueinit)
944983 end
945984 else
946985 u0, p, defs = get_u0_p (sys,
@@ -975,6 +1014,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
9751014 sparse = sparse, eval_expression = eval_expression,
9761015 initializeprob = initializeprob,
9771016 initializeprobmap = initializeprobmap,
1017+ initializeprob_updatep! = initializeprob_updatep!,
9781018 kwargs... )
9791019 implicit_dae ? (f, du0, u0, p) : (f, u0, p)
9801020end
@@ -1602,13 +1642,15 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
16021642 if ! iscomplete (sys)
16031643 error (" A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`" )
16041644 end
1645+ parammap = parammap isa SciMLBase. NullParameters ? Dict () : todict (parammap)
16051646 if isempty (u0map) && get_initializesystem (sys) != = nothing
16061647 isys = get_initializesystem (sys)
16071648 elseif isempty (u0map) && get_initializesystem (sys) === nothing
1608- isys = structural_simplify (generate_initializesystem (sys); fully_determined = false )
1649+ isys = structural_simplify (
1650+ generate_initializesystem (sys; pmap = parammap); fully_determined = false )
16091651 else
16101652 isys = structural_simplify (
1611- generate_initializesystem (sys; u0map); fully_determined = false )
1653+ generate_initializesystem (sys; u0map, pmap = parammap ); fully_determined = false )
16121654 end
16131655
16141656 uninit = setdiff (unknowns (sys), [unknowns (isys); getfield .(observed (isys), :lhs )])
@@ -1628,10 +1670,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
16281670 if warn_initialize_determined && neqs < nunknown
16291671 @warn " Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false."
16301672 end
1631-
1632- parammap = parammap isa DiffEqBase. NullParameters || isempty (parammap) ?
1633- [get_iv (sys) => t] :
1634- merge (todict (parammap), Dict (get_iv (sys) => t))
1673+ parammap[get_iv (sys)] = t
16351674 if isempty (u0map)
16361675 u0map = Dict ()
16371676 end
0 commit comments