@@ -247,25 +247,6 @@ function recursive_unwrap(x::AbstractDict)
247247 return anydict (unwrap (k) => recursive_unwrap (v) for (k, v) in x)
248248end
249249
250- """
251- $(TYPEDSIGNATURES)
252-
253- Return the appropriate zero value for a symbolic variable representing a number or array of
254- numbers. Sized array symbolics return a zero-filled array of matching size. Unsized array
255- symbolics return an empty array of the appropriate `eltype`.
256- """
257- function zero_var (x:: Symbolic{T} ) where {V <: Number , T <: Union{V, AbstractArray{V}} }
258- if Symbolics. isarraysymbolic (x)
259- if is_sized_array_symbolic (x)
260- return zeros (eltype (T), size (x))
261- else
262- return T[]
263- end
264- else
265- return zero (T)
266- end
267- end
268-
269250"""
270251 $(TYPEDSIGNATURES)
271252
@@ -362,7 +343,7 @@ Keyword arguments:
362343- `is_initializeprob, guesses`: Used to determine whether the system is missing guesses.
363344"""
364345function better_varmap_to_vars (varmap:: AbstractDict , vars:: Vector ;
365- tofloat = true , container_type = Array,
346+ tofloat = true , container_type = Array, floatT = Nothing,
366347 toterm = default_toterm, promotetoconcrete = nothing , check = true ,
367348 allow_symbolic = false , is_initializeprob = false )
368349 isempty (vars) && return nothing
@@ -385,6 +366,9 @@ function better_varmap_to_vars(varmap::AbstractDict, vars::Vector;
385366 is_initializeprob ? throw (MissingGuessError (missingsyms, missingvals)) :
386367 throw (UnexpectedSymbolicValueInVarmap (missingsyms[1 ], missingvals[1 ]))
387368 end
369+ if tofloat && ! (floatT == Nothing)
370+ vals = floatT .(vals)
371+ end
388372 end
389373
390374 if container_type <: Union{AbstractDict, Tuple, Nothing, SciMLBase.NullParameters}
@@ -533,12 +517,12 @@ function (f::UpdateInitializeprob)(initializeprob, prob)
533517 f. setvals (initializeprob, f. getvals (prob))
534518end
535519
536- function get_temporary_value (p)
520+ function get_temporary_value (p, floatT = Float64 )
537521 stype = symtype (unwrap (p))
538522 return if stype == Real
539- zero (Float64 )
523+ zero (floatT )
540524 elseif stype <: AbstractArray{Real}
541- zeros (Float64 , size (p))
525+ zeros (floatT , size (p))
542526 elseif stype <: Real
543527 zero (stype)
544528 elseif stype <: AbstractArray
@@ -648,15 +632,32 @@ All other keyword arguments are forwarded to `InitializationProblem`.
648632"""
649633function maybe_build_initialization_problem (
650634 sys:: AbstractSystem , op:: AbstractDict , u0map, pmap, t, defs,
651- guesses, missing_unknowns; implicit_dae = false , u0_constructor = identity, kwargs... )
635+ guesses, missing_unknowns; implicit_dae = false ,
636+ u0_constructor = identity, floatT = Float64, kwargs... )
652637 guesses = merge (ModelingToolkit. guesses (sys), todict (guesses))
653638
654639 if t === nothing && is_time_dependent (sys)
655- t = 0.0
640+ t = zero (floatT)
656641 end
657642
658643 initializeprob = ModelingToolkit. InitializationProblem {true, SciMLBase.FullSpecialize} (
659644 sys, t, u0map, pmap; guesses, kwargs... )
645+ if state_values (initializeprob) != = nothing
646+ initializeprob = remake (initializeprob; u0 = floatT .(state_values (initializeprob)))
647+ end
648+ initp = parameter_values (initializeprob)
649+ if is_split (sys)
650+ buffer, repack, _ = SciMLStructures. canonicalize (SciMLStructures. Tunable (), initp)
651+ initp = repack (floatT .(buffer))
652+ buffer, repack, _ = SciMLStructures. canonicalize (SciMLStructures. Initials (), initp)
653+ initp = repack (floatT .(buffer))
654+ elseif initp isa AbstractArray
655+ initp′ = similar (initp, floatT)
656+ copyto! (initp′, initp)
657+ initp = initp′
658+ end
659+ initializeprob = remake (initializeprob; p = initp)
660+
660661 meta = get_metadata (initializeprob. f. sys)
661662
662663 if is_time_dependent (sys)
@@ -692,7 +693,7 @@ function maybe_build_initialization_problem(
692693 get (op, p, missing ) === missing || continue
693694 p = unwrap (p)
694695 stype = symtype (p)
695- op[p] = get_temporary_value (p)
696+ op[p] = get_temporary_value (p, floatT )
696697 if iscall (p) && operation (p) === getindex
697698 arrp = arguments (p)[1 ]
698699 op[arrp] = collect (arrp)
@@ -701,7 +702,7 @@ function maybe_build_initialization_problem(
701702
702703 if is_time_dependent (sys)
703704 for v in missing_unknowns
704- op[v] = zero_var (v )
705+ op[v] = get_temporary_value (v, floatT )
705706 end
706707 empty! (missing_unknowns)
707708 end
@@ -712,6 +713,26 @@ function maybe_build_initialization_problem(
712713 initializeprobpmap))
713714end
714715
716+ """
717+ $(TYPEDSIGNATURES)
718+
719+ Calculate the floating point type to use from the given `varmap` by looking at variables
720+ with a constant value.
721+ """
722+ function float_type_from_varmap (varmap, floatT = Bool)
723+ for (k, v) in varmap
724+ symbolic_type (v) == NotSymbolic () || continue
725+ is_array_of_symbolics (v) && continue
726+
727+ if v isa AbstractArray
728+ floatT = promote_type (floatT, eltype (v))
729+ elseif v isa Real
730+ floatT = promote_type (floatT, typeof (v))
731+ end
732+ end
733+ return float (floatT)
734+ end
735+
715736"""
716737 $(TYPEDSIGNATURES)
717738
@@ -815,12 +836,19 @@ function process_SciMLProblem(
815836 op, missing_unknowns, missing_pars = build_operating_point! (sys,
816837 u0map, pmap, defs, cmap, dvs, ps)
817838
839+ floatT = Bool
840+ if u0Type <: AbstractArray && eltype (u0Type) <: Real
841+ floatT = float (eltype (u0Type))
842+ else
843+ floatT = float_type_from_varmap (op, floatT)
844+ end
845+
818846 if ! is_time_dependent (sys) || is_initializesystem (sys)
819847 add_observed_equations! (u0map, obs)
820848 end
821849 if u0_constructor === identity && u0Type <: StaticArray
822850 u0_constructor = vals -> SymbolicUtils. Code. create_array (
823- u0Type, eltype (vals) , Val (1 ), Val (length (vals)), vals... )
851+ u0Type, floatT , Val (1 ), Val (length (vals)), vals... )
824852 end
825853 if build_initializeprob
826854 kws = maybe_build_initialization_problem (
@@ -830,7 +858,7 @@ function process_SciMLProblem(
830858 warn_cyclic_dependency, check_units = check_initialization_units,
831859 circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc,
832860 force_time_independent = force_initialization_time_independent, algebraic_only, allow_incomplete,
833- u0_constructor)
861+ u0_constructor, floatT )
834862
835863 kwargs = merge (kwargs, kws)
836864 end
@@ -858,7 +886,7 @@ function process_SciMLProblem(
858886 evaluate_varmap! (op, dvs; limit = substitution_limit)
859887
860888 u0 = better_varmap_to_vars (
861- op, dvs; tofloat,
889+ op, dvs; tofloat, floatT,
862890 container_type = u0Type, allow_symbolic = symbolic_u0, is_initializeprob)
863891
864892 if u0 != = nothing
@@ -882,7 +910,7 @@ function process_SciMLProblem(
882910 end
883911 evaluate_varmap! (op, ps; limit = substitution_limit)
884912 if is_split (sys)
885- p = MTKParameters (sys, op)
913+ p = MTKParameters (sys, op; floatT = floatT )
886914 else
887915 p = better_varmap_to_vars (op, ps; tofloat, container_type = pType)
888916 end
@@ -898,6 +926,16 @@ function process_SciMLProblem(
898926 du0 = nothing
899927 end
900928
929+ if build_initializeprob
930+ t0 = t
931+ if is_time_dependent (sys) && t0 === nothing
932+ t0 = zero (floatT)
933+ end
934+ initialization_data = SciMLBase. remake_initialization_data (
935+ kwargs. initialization_data, kwargs, u0, t0, p, u0, p)
936+ kwargs = merge (kwargs,)
937+ end
938+
901939 f = constructor (sys, dvs, ps, u0; p = p,
902940 eval_expression = eval_expression,
903941 eval_module = eval_module,
0 commit comments