@@ -1192,6 +1192,20 @@ function float_type_from_varmap(varmap, floatT = Bool)
1192
1192
return float (floatT)
1193
1193
end
1194
1194
1195
+ """
1196
+ $(TYPEDSIGNATURES)
1197
+
1198
+ Calculate the floating point type to use from the given `varmap` by looking at variables
1199
+ with a constant value. `u0Type` takes priority if it is a real-valued array type.
1200
+ """
1201
+ function calculate_float_type (varmap, u0Type:: Type , floatT = Bool)
1202
+ if u0Type <: AbstractArray && eltype (u0Type) <: Real && eltype (u0Type) != Union{}
1203
+ return float (eltype (u0Type))
1204
+ else
1205
+ return float_type_from_varmap (varmap, floatT)
1206
+ end
1207
+ end
1208
+
1195
1209
"""
1196
1210
$(TYPEDSIGNATURES)
1197
1211
@@ -1208,6 +1222,41 @@ function calculate_resid_prototype(N::Int, u0, p)
1208
1222
return zeros (u0ElType, N)
1209
1223
end
1210
1224
1225
+ """
1226
+ $(TYPEDSIGNATURES)
1227
+
1228
+ Given the user-provided value of `u0_constructor`, the container type of user-provided
1229
+ `op`, the desired floating point type and whether a symbolic `u0` is allowed, return the
1230
+ updated `u0_constructor`.
1231
+ """
1232
+ function get_u0_constructor (u0_constructor, u0Type:: Type , floatT:: Type , symbolic_u0:: Bool )
1233
+ u0_constructor === identity || return u0_constructor
1234
+ u0Type <: StaticArray || return u0_constructor
1235
+ return function (vals)
1236
+ elT = if symbolic_u0 && any (x -> symbolic_type (x) != NotSymbolic (), vals)
1237
+ nothing
1238
+ else
1239
+ floatT
1240
+ end
1241
+ SymbolicUtils. Code. create_array (u0Type, elT, Val (1 ), Val (length (vals)), vals... )
1242
+ end
1243
+ end
1244
+
1245
+ """
1246
+ $(TYPEDSIGNATURES)
1247
+
1248
+ Given the user-provided value of `p_constructor`, the container type of user-provided `op`,
1249
+ ans the desired floating point type, return the updated `p_constructor`.
1250
+ """
1251
+ function get_p_constructor (p_constructor, pType:: Type , floatT:: Type )
1252
+ p_constructor === identity || return p_constructor
1253
+ pType <: StaticArray || return p_constructor
1254
+ return function (vals)
1255
+ SymbolicUtils. Code. create_array (
1256
+ pType, floatT, Val (ndims (vals)), Val (size (vals)), vals... )
1257
+ end
1258
+ end
1259
+
1211
1260
"""
1212
1261
$(TYPEDSIGNATURES)
1213
1262
@@ -1274,26 +1323,15 @@ function process_SciMLProblem(
1274
1323
missing_unknowns, missing_pars = build_operating_point! (sys, op,
1275
1324
u0map, pmap, defs, dvs, ps)
1276
1325
1277
- floatT = Bool
1278
- if u0Type <: AbstractArray && eltype (u0Type) <: Real && eltype (u0Type) != Union{}
1279
- floatT = float (eltype (u0Type))
1280
- else
1281
- floatT = float_type_from_varmap (op, floatT)
1282
- end
1283
-
1326
+ floatT = calculate_float_type (op, u0Type)
1284
1327
u0_eltype = something (u0_eltype, floatT)
1285
1328
1286
1329
if ! is_time_dependent (sys) || is_initializesystem (sys)
1287
1330
add_observed_equations! (op, obs)
1288
1331
end
1289
- if u0_constructor === identity && u0Type <: StaticArray
1290
- u0_constructor = vals -> SymbolicUtils. Code. create_array (
1291
- u0Type, floatT, Val (1 ), Val (length (vals)), vals... )
1292
- end
1293
- if p_constructor === identity && pType <: StaticArray
1294
- p_constructor = vals -> SymbolicUtils. Code. create_array (
1295
- pType, floatT, Val (1 ), Val (length (vals)), vals... )
1296
- end
1332
+
1333
+ u0_constructor = get_u0_constructor (u0_constructor, u0Type, u0_eltype, symbolic_u0)
1334
+ p_constructor = get_p_constructor (p_constructor, pType, floatT)
1297
1335
1298
1336
if build_initializeprob
1299
1337
kws = maybe_build_initialization_problem (
0 commit comments