Skip to content

Commit 3719c68

Browse files
refactor: modularize process_SciMLProblem a bit more
1 parent 31d9cdb commit 3719c68

File tree

1 file changed

+53
-15
lines changed

1 file changed

+53
-15
lines changed

src/systems/problem_utils.jl

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1192,6 +1192,20 @@ function float_type_from_varmap(varmap, floatT = Bool)
11921192
return float(floatT)
11931193
end
11941194

1195+
"""
1196+
$(TYPEDSIGNATURES)
1197+
1198+
Calculate the floating point type to use from the given `varmap` by looking at variables
1199+
with a constant value. `u0Type` takes priority if it is a real-valued array type.
1200+
"""
1201+
function calculate_float_type(varmap, u0Type::Type, floatT = Bool)
1202+
if u0Type <: AbstractArray && eltype(u0Type) <: Real && eltype(u0Type) != Union{}
1203+
return float(eltype(u0Type))
1204+
else
1205+
return float_type_from_varmap(varmap, floatT)
1206+
end
1207+
end
1208+
11951209
"""
11961210
$(TYPEDSIGNATURES)
11971211
@@ -1208,6 +1222,41 @@ function calculate_resid_prototype(N::Int, u0, p)
12081222
return zeros(u0ElType, N)
12091223
end
12101224

1225+
"""
1226+
$(TYPEDSIGNATURES)
1227+
1228+
Given the user-provided value of `u0_constructor`, the container type of user-provided
1229+
`op`, the desired floating point type and whether a symbolic `u0` is allowed, return the
1230+
updated `u0_constructor`.
1231+
"""
1232+
function get_u0_constructor(u0_constructor, u0Type::Type, floatT::Type, symbolic_u0::Bool)
1233+
u0_constructor === identity || return u0_constructor
1234+
u0Type <: StaticArray || return u0_constructor
1235+
return function (vals)
1236+
elT = if symbolic_u0 && any(x -> symbolic_type(x) != NotSymbolic(), vals)
1237+
nothing
1238+
else
1239+
floatT
1240+
end
1241+
SymbolicUtils.Code.create_array(u0Type, elT, Val(1), Val(length(vals)), vals...)
1242+
end
1243+
end
1244+
1245+
"""
1246+
$(TYPEDSIGNATURES)
1247+
1248+
Given the user-provided value of `p_constructor`, the container type of user-provided `op`,
1249+
ans the desired floating point type, return the updated `p_constructor`.
1250+
"""
1251+
function get_p_constructor(p_constructor, pType::Type, floatT::Type)
1252+
p_constructor === identity || return p_constructor
1253+
pType <: StaticArray || return p_constructor
1254+
return function (vals)
1255+
SymbolicUtils.Code.create_array(
1256+
pType, floatT, Val(ndims(vals)), Val(size(vals)), vals...)
1257+
end
1258+
end
1259+
12111260
"""
12121261
$(TYPEDSIGNATURES)
12131262
@@ -1274,26 +1323,15 @@ function process_SciMLProblem(
12741323
missing_unknowns, missing_pars = build_operating_point!(sys, op,
12751324
u0map, pmap, defs, dvs, ps)
12761325

1277-
floatT = Bool
1278-
if u0Type <: AbstractArray && eltype(u0Type) <: Real && eltype(u0Type) != Union{}
1279-
floatT = float(eltype(u0Type))
1280-
else
1281-
floatT = float_type_from_varmap(op, floatT)
1282-
end
1283-
1326+
floatT = calculate_float_type(op, u0Type)
12841327
u0_eltype = something(u0_eltype, floatT)
12851328

12861329
if !is_time_dependent(sys) || is_initializesystem(sys)
12871330
add_observed_equations!(op, obs)
12881331
end
1289-
if u0_constructor === identity && u0Type <: StaticArray
1290-
u0_constructor = vals -> SymbolicUtils.Code.create_array(
1291-
u0Type, floatT, Val(1), Val(length(vals)), vals...)
1292-
end
1293-
if p_constructor === identity && pType <: StaticArray
1294-
p_constructor = vals -> SymbolicUtils.Code.create_array(
1295-
pType, floatT, Val(1), Val(length(vals)), vals...)
1296-
end
1332+
1333+
u0_constructor = get_u0_constructor(u0_constructor, u0Type, u0_eltype, symbolic_u0)
1334+
p_constructor = get_p_constructor(p_constructor, pType, floatT)
12971335

12981336
if build_initializeprob
12991337
kws = maybe_build_initialization_problem(

0 commit comments

Comments
 (0)