@@ -1172,6 +1172,20 @@ function float_type_from_varmap(varmap, floatT = Bool)
11721172 return float (floatT)
11731173end
11741174
1175+ """
1176+ $(TYPEDSIGNATURES)
1177+
1178+ Calculate the floating point type to use from the given `varmap` by looking at variables
1179+ with a constant value. `u0Type` takes priority if it is a real-valued array type.
1180+ """
1181+ function calculate_float_type (varmap, u0Type:: Type , floatT = Bool)
1182+ if u0Type <: AbstractArray && eltype (u0Type) <: Real && eltype (u0Type) != Union{}
1183+ return float (eltype (u0Type))
1184+ else
1185+ return float_type_from_varmap (varmap, floatT)
1186+ end
1187+ end
1188+
11751189"""
11761190 $(TYPEDSIGNATURES)
11771191
@@ -1188,6 +1202,41 @@ function calculate_resid_prototype(N::Int, u0, p)
11881202 return zeros (u0ElType, N)
11891203end
11901204
1205+ """
1206+ $(TYPEDSIGNATURES)
1207+
1208+ Given the user-provided value of `u0_constructor`, the container type of user-provided
1209+ `op`, the desired floating point type and whether a symbolic `u0` is allowed, return the
1210+ updated `u0_constructor`.
1211+ """
1212+ function get_u0_constructor (u0_constructor, u0Type:: Type , floatT:: Type , symbolic_u0:: Bool )
1213+ u0_constructor === identity || return u0_constructor
1214+ u0Type <: StaticArray || return u0_constructor
1215+ return function (vals)
1216+ elT = if symbolic_u0 && any (x -> symbolic_type (x) != NotSymbolic (), vals)
1217+ nothing
1218+ else
1219+ floatT
1220+ end
1221+ SymbolicUtils. Code. create_array (u0Type, elT, Val (1 ), Val (length (vals)), vals... )
1222+ end
1223+ end
1224+
1225+ """
1226+ $(TYPEDSIGNATURES)
1227+
1228+ Given the user-provided value of `p_constructor`, the container type of user-provided `op`,
1229+ ans the desired floating point type, return the updated `p_constructor`.
1230+ """
1231+ function get_p_constructor (p_constructor, pType:: Type , floatT:: Type )
1232+ p_constructor === identity || return p_constructor
1233+ pType <: StaticArray || return p_constructor
1234+ return function (vals)
1235+ SymbolicUtils. Code. create_array (
1236+ pType, floatT, Val (ndims (vals)), Val (size (vals)), vals... )
1237+ end
1238+ end
1239+
11911240"""
11921241 $(TYPEDSIGNATURES)
11931242
@@ -1254,26 +1303,15 @@ function process_SciMLProblem(
12541303 missing_unknowns, missing_pars = build_operating_point! (sys, op,
12551304 u0map, pmap, defs, dvs, ps)
12561305
1257- floatT = Bool
1258- if u0Type <: AbstractArray && eltype (u0Type) <: Real && eltype (u0Type) != Union{}
1259- floatT = float (eltype (u0Type))
1260- else
1261- floatT = float_type_from_varmap (op, floatT)
1262- end
1263-
1306+ floatT = calculate_float_type (op, u0Type)
12641307 u0_eltype = something (u0_eltype, floatT)
12651308
12661309 if ! is_time_dependent (sys) || is_initializesystem (sys)
12671310 add_observed_equations! (op, obs)
12681311 end
1269- if u0_constructor === identity && u0Type <: StaticArray
1270- u0_constructor = vals -> SymbolicUtils. Code. create_array (
1271- u0Type, floatT, Val (1 ), Val (length (vals)), vals... )
1272- end
1273- if p_constructor === identity && pType <: StaticArray
1274- p_constructor = vals -> SymbolicUtils. Code. create_array (
1275- pType, floatT, Val (1 ), Val (length (vals)), vals... )
1276- end
1312+
1313+ u0_constructor = get_u0_constructor (u0_constructor, u0Type, u0_eltype, symbolic_u0)
1314+ p_constructor = get_p_constructor (p_constructor, pType, floatT)
12771315
12781316 if build_initializeprob
12791317 kws = maybe_build_initialization_problem (
0 commit comments