@@ -357,7 +357,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
357
357
analytic = nothing ,
358
358
split_idxs = nothing ,
359
359
initializeprob = nothing ,
360
+ update_initializeprob! = nothing ,
360
361
initializeprobmap = nothing ,
362
+ initializeprobpmap = nothing ,
361
363
kwargs... ) where {iip, specialize}
362
364
if ! iscomplete (sys)
363
365
error (" A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`" )
@@ -459,7 +461,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
459
461
sparsity = sparsity ? jacobian_sparsity (sys) : nothing ,
460
462
analytic = analytic,
461
463
initializeprob = initializeprob,
462
- initializeprobmap = initializeprobmap)
464
+ update_initializeprob! = update_initializeprob!,
465
+ initializeprobmap = initializeprobmap,
466
+ initializeprobpmap = initializeprobpmap)
463
467
end
464
468
465
469
"""
@@ -789,6 +793,45 @@ function get_u0(
789
793
return u0, defs
790
794
end
791
795
796
+ struct GetUpdatedMTKParameters{G, S}
797
+ # `getu` functor which gets parameters that are unknowns during initialization
798
+ getpunknowns:: G
799
+ # `setu` functor which returns a modified MTKParameters using those parameters
800
+ setpunknowns:: S
801
+ end
802
+
803
+ function (f:: GetUpdatedMTKParameters )(prob, initializesol)
804
+ mtkp = copy (parameter_values (prob))
805
+ f. setpunknowns (mtkp, f. getpunknowns (initializesol))
806
+ mtkp
807
+ end
808
+
809
+ struct UpdateInitializeprob{G, S}
810
+ # `getu` functor which gets all values from prob
811
+ getvals:: G
812
+ # `setu` functor which updates initializeprob with values
813
+ setvals:: S
814
+ end
815
+
816
+ function (f:: UpdateInitializeprob )(initializeprob, prob)
817
+ f. setvals (initializeprob, f. getvals (prob))
818
+ end
819
+
820
+ function get_temporary_value (p)
821
+ stype = symtype (unwrap (p))
822
+ return if stype == Real
823
+ zero (Float64)
824
+ elseif stype <: AbstractArray{Real}
825
+ zeros (Float64, size (p))
826
+ elseif stype <: Real
827
+ zero (stype)
828
+ elseif stype <: AbstractArray
829
+ zeros (eltype (stype), size (p))
830
+ else
831
+ error (" Nonnumeric parameter $p with symtype $stype cannot be solved for during initialization" )
832
+ end
833
+ end
834
+
792
835
function process_DEProblem (constructor, sys:: AbstractODESystem , u0map, parammap;
793
836
implicit_dae = false , du0map = nothing ,
794
837
version = nothing , tgrad = false ,
@@ -829,18 +872,38 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
829
872
end
830
873
831
874
if eltype (parammap) <: Pair
832
- parammap = Dict (unwrap (k) => v for (k, v) in todict ( parammap) )
875
+ parammap = Dict {Any, Any} (unwrap (k) => v for (k, v) in parammap)
833
876
elseif parammap isa AbstractArray
834
877
if isempty (parammap)
835
878
parammap = SciMLBase. NullParameters ()
836
879
else
837
- parammap = Dict (unwrap .(parameters (sys)) .=> parammap)
880
+ parammap = Dict {Any, Any} (unwrap .(parameters (sys)) .=> parammap)
838
881
end
839
882
end
840
-
883
+ defs = defaults (sys)
884
+ if has_guesses (sys)
885
+ guesses = merge (
886
+ ModelingToolkit. guesses (sys), isempty (guesses) ? Dict () : todict (guesses))
887
+ solvablepars = [p
888
+ for p in parameters (sys)
889
+ if is_parameter_solvable (p, parammap, defs, guesses)]
890
+
891
+ pvarmap = if parammap === nothing || parammap == SciMLBase. NullParameters () || ! (eltype (parammap) <: Pair ) && isempty (parammap)
892
+ defs
893
+ else
894
+ merge (defs, todict (parammap))
895
+ end
896
+ setparobserved = filter (keys (pvarmap)) do var
897
+ has_parameter_dependency_with_lhs (sys, var)
898
+ end
899
+ else
900
+ solvablepars = ()
901
+ setparobserved = ()
902
+ end
841
903
# ModelingToolkit.get_tearing_state(sys) !== nothing => Requires structural_simplify first
842
904
if sys isa ODESystem && build_initializeprob &&
843
- (((implicit_dae || ! isempty (missingvars) || ! isempty (setobserved)) &&
905
+ (((implicit_dae || ! isempty (missingvars) || ! isempty (solvablepars) ||
906
+ ! isempty (setobserved) || ! isempty (setparobserved)) &&
844
907
ModelingToolkit. get_tearing_state (sys) != = nothing ) ||
845
908
! isempty (initialization_equations (sys))) && t != = nothing
846
909
if eltype (u0map) <: Number
@@ -854,14 +917,32 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
854
917
sys, t, u0map, parammap; guesses, warn_initialize_determined,
855
918
initialization_eqs, eval_expression, eval_module, fully_determined, check_units)
856
919
initializeprobmap = getu (initializeprob, unknowns (sys))
920
+ punknowns = [p
921
+ for p in all_variable_symbols (initializeprob) if is_parameter (sys, p)]
922
+ getpunknowns = getu (initializeprob, punknowns)
923
+ setpunknowns = setp (sys, punknowns)
924
+ initializeprobpmap = GetUpdatedMTKParameters (getpunknowns, setpunknowns)
925
+ reqd_syms = parameter_symbols (initializeprob)
926
+ update_initializeprob! = UpdateInitializeprob (
927
+ getu (sys, reqd_syms), setu (initializeprob, reqd_syms))
857
928
858
929
zerovars = Dict (setdiff (unknowns (sys), keys (defaults (sys))) .=> 0.0 )
930
+ if parammap isa SciMLBase. NullParameters
931
+ parammap = Dict ()
932
+ end
933
+ for p in punknowns
934
+ p = unwrap (p)
935
+ stype = symtype (p)
936
+ parammap[p] = get_temporary_value (p)
937
+ end
859
938
trueinit = collect (merge (zerovars, eltype (u0map) <: Pair ? todict (u0map) : u0map))
860
939
u0map isa StaticArraysCore. StaticArray &&
861
940
(trueinit = SVector {length(trueinit)} (trueinit))
862
941
else
863
942
initializeprob = nothing
943
+ update_initializeprob! = nothing
864
944
initializeprobmap = nothing
945
+ initializeprobpmap = nothing
865
946
trueinit = u0map
866
947
end
867
948
@@ -909,7 +990,9 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
909
990
sparse = sparse, eval_expression = eval_expression,
910
991
eval_module = eval_module,
911
992
initializeprob = initializeprob,
993
+ update_initializeprob! = update_initializeprob!,
912
994
initializeprobmap = initializeprobmap,
995
+ initializeprobpmap = initializeprobpmap,
913
996
kwargs... )
914
997
implicit_dae ? (f, du0, u0, p) : (f, u0, p)
915
998
end
@@ -1471,10 +1554,12 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
1471
1554
isys = get_initializesystem (sys; initialization_eqs, check_units)
1472
1555
elseif isempty (u0map) && get_initializesystem (sys) === nothing
1473
1556
isys = structural_simplify (
1474
- generate_initializesystem (sys; initialization_eqs, check_units); fully_determined)
1557
+ generate_initializesystem (
1558
+ sys; initialization_eqs, check_units, pmap = parammap); fully_determined)
1475
1559
else
1476
1560
isys = structural_simplify (
1477
- generate_initializesystem (sys; u0map, initialization_eqs, check_units); fully_determined)
1561
+ generate_initializesystem (
1562
+ sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined)
1478
1563
end
1479
1564
1480
1565
uninit = setdiff (unknowns (sys), [unknowns (isys); getfield .(observed (isys), :lhs )])
@@ -1498,14 +1583,15 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem,
1498
1583
parammap = parammap isa DiffEqBase. NullParameters || isempty (parammap) ?
1499
1584
[get_iv (sys) => t] :
1500
1585
merge (todict (parammap), Dict (get_iv (sys) => t))
1586
+ parammap = Dict (k => v for (k, v) in parammap if v != = missing )
1501
1587
if isempty (u0map)
1502
1588
u0map = Dict ()
1503
1589
end
1504
1590
if isempty (guesses)
1505
1591
guesses = Dict ()
1506
1592
end
1507
1593
1508
- u0map = merge (todict (guesses), todict (u0map))
1594
+ u0map = merge (ModelingToolkit . guesses (sys), todict (guesses), todict (u0map))
1509
1595
if neqs == nunknown
1510
1596
NonlinearProblem (isys, u0map, parammap; kwargs... )
1511
1597
else
0 commit comments