@@ -1192,6 +1192,20 @@ function float_type_from_varmap(varmap, floatT = Bool)
11921192 return float (floatT)
11931193end
11941194
1195+ """
1196+ $(TYPEDSIGNATURES)
1197+
1198+ Calculate the floating point type to use from the given `varmap` by looking at variables
1199+ with a constant value. `u0Type` takes priority if it is a real-valued array type.
1200+ """
1201+ function calculate_float_type (varmap, u0Type:: Type , floatT = Bool)
1202+ if u0Type <: AbstractArray && eltype (u0Type) <: Real && eltype (u0Type) != Union{}
1203+ return float (eltype (u0Type))
1204+ else
1205+ return float_type_from_varmap (varmap, floatT)
1206+ end
1207+ end
1208+
11951209"""
11961210 $(TYPEDSIGNATURES)
11971211
@@ -1208,6 +1222,41 @@ function calculate_resid_prototype(N::Int, u0, p)
12081222 return zeros (u0ElType, N)
12091223end
12101224
1225+ """
1226+ $(TYPEDSIGNATURES)
1227+
1228+ Given the user-provided value of `u0_constructor`, the container type of user-provided
1229+ `op`, the desired floating point type and whether a symbolic `u0` is allowed, return the
1230+ updated `u0_constructor`.
1231+ """
1232+ function get_u0_constructor (u0_constructor, u0Type:: Type , floatT:: Type , symbolic_u0:: Bool )
1233+ u0_constructor === identity || return u0_constructor
1234+ u0Type <: StaticArray || return u0_constructor
1235+ return function (vals)
1236+ elT = if symbolic_u0 && any (x -> symbolic_type (x) != NotSymbolic (), vals)
1237+ nothing
1238+ else
1239+ floatT
1240+ end
1241+ SymbolicUtils. Code. create_array (u0Type, elT, Val (1 ), Val (length (vals)), vals... )
1242+ end
1243+ end
1244+
1245+ """
1246+ $(TYPEDSIGNATURES)
1247+
1248+ Given the user-provided value of `p_constructor`, the container type of user-provided `op`,
1249+ ans the desired floating point type, return the updated `p_constructor`.
1250+ """
1251+ function get_p_constructor (p_constructor, pType:: Type , floatT:: Type )
1252+ p_constructor === identity || return p_constructor
1253+ pType <: StaticArray || return p_constructor
1254+ return function (vals)
1255+ SymbolicUtils. Code. create_array (
1256+ pType, floatT, Val (ndims (vals)), Val (size (vals)), vals... )
1257+ end
1258+ end
1259+
12111260"""
12121261 $(TYPEDSIGNATURES)
12131262
@@ -1274,26 +1323,15 @@ function process_SciMLProblem(
12741323 missing_unknowns, missing_pars = build_operating_point! (sys, op,
12751324 u0map, pmap, defs, dvs, ps)
12761325
1277- floatT = Bool
1278- if u0Type <: AbstractArray && eltype (u0Type) <: Real && eltype (u0Type) != Union{}
1279- floatT = float (eltype (u0Type))
1280- else
1281- floatT = float_type_from_varmap (op, floatT)
1282- end
1283-
1326+ floatT = calculate_float_type (op, u0Type)
12841327 u0_eltype = something (u0_eltype, floatT)
12851328
12861329 if ! is_time_dependent (sys) || is_initializesystem (sys)
12871330 add_observed_equations! (op, obs)
12881331 end
1289- if u0_constructor === identity && u0Type <: StaticArray
1290- u0_constructor = vals -> SymbolicUtils. Code. create_array (
1291- u0Type, floatT, Val (1 ), Val (length (vals)), vals... )
1292- end
1293- if p_constructor === identity && pType <: StaticArray
1294- p_constructor = vals -> SymbolicUtils. Code. create_array (
1295- pType, floatT, Val (1 ), Val (length (vals)), vals... )
1296- end
1332+
1333+ u0_constructor = get_u0_constructor (u0_constructor, u0Type, u0_eltype, symbolic_u0)
1334+ p_constructor = get_p_constructor (p_constructor, pType, floatT)
12971335
12981336 if build_initializeprob
12991337 kws = maybe_build_initialization_problem (
0 commit comments