Skip to content

Commit 4d52b0c

Browse files
refactor: modularize process_SciMLProblem a bit more
1 parent bb5c5c3 commit 4d52b0c

File tree

1 file changed

+53
-15
lines changed

1 file changed

+53
-15
lines changed

src/systems/problem_utils.jl

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1172,6 +1172,20 @@ function float_type_from_varmap(varmap, floatT = Bool)
11721172
return float(floatT)
11731173
end
11741174

1175+
"""
1176+
$(TYPEDSIGNATURES)
1177+
1178+
Calculate the floating point type to use from the given `varmap` by looking at variables
1179+
with a constant value. `u0Type` takes priority if it is a real-valued array type.
1180+
"""
1181+
function calculate_float_type(varmap, u0Type::Type, floatT = Bool)
1182+
if u0Type <: AbstractArray && eltype(u0Type) <: Real && eltype(u0Type) != Union{}
1183+
return float(eltype(u0Type))
1184+
else
1185+
return float_type_from_varmap(varmap, floatT)
1186+
end
1187+
end
1188+
11751189
"""
11761190
$(TYPEDSIGNATURES)
11771191
@@ -1188,6 +1202,41 @@ function calculate_resid_prototype(N::Int, u0, p)
11881202
return zeros(u0ElType, N)
11891203
end
11901204

1205+
"""
1206+
$(TYPEDSIGNATURES)
1207+
1208+
Given the user-provided value of `u0_constructor`, the container type of user-provided
1209+
`op`, the desired floating point type and whether a symbolic `u0` is allowed, return the
1210+
updated `u0_constructor`.
1211+
"""
1212+
function get_u0_constructor(u0_constructor, u0Type::Type, floatT::Type, symbolic_u0::Bool)
1213+
u0_constructor === identity || return u0_constructor
1214+
u0Type <: StaticArray || return u0_constructor
1215+
return function (vals)
1216+
elT = if symbolic_u0 && any(x -> symbolic_type(x) != NotSymbolic(), vals)
1217+
nothing
1218+
else
1219+
floatT
1220+
end
1221+
SymbolicUtils.Code.create_array(u0Type, elT, Val(1), Val(length(vals)), vals...)
1222+
end
1223+
end
1224+
1225+
"""
1226+
$(TYPEDSIGNATURES)
1227+
1228+
Given the user-provided value of `p_constructor`, the container type of user-provided `op`,
1229+
ans the desired floating point type, return the updated `p_constructor`.
1230+
"""
1231+
function get_p_constructor(p_constructor, pType::Type, floatT::Type)
1232+
p_constructor === identity || return p_constructor
1233+
pType <: StaticArray || return p_constructor
1234+
return function (vals)
1235+
SymbolicUtils.Code.create_array(
1236+
pType, floatT, Val(ndims(vals)), Val(size(vals)), vals...)
1237+
end
1238+
end
1239+
11911240
"""
11921241
$(TYPEDSIGNATURES)
11931242
@@ -1254,26 +1303,15 @@ function process_SciMLProblem(
12541303
missing_unknowns, missing_pars = build_operating_point!(sys, op,
12551304
u0map, pmap, defs, dvs, ps)
12561305

1257-
floatT = Bool
1258-
if u0Type <: AbstractArray && eltype(u0Type) <: Real && eltype(u0Type) != Union{}
1259-
floatT = float(eltype(u0Type))
1260-
else
1261-
floatT = float_type_from_varmap(op, floatT)
1262-
end
1263-
1306+
floatT = calculate_float_type(op, u0Type)
12641307
u0_eltype = something(u0_eltype, floatT)
12651308

12661309
if !is_time_dependent(sys) || is_initializesystem(sys)
12671310
add_observed_equations!(op, obs)
12681311
end
1269-
if u0_constructor === identity && u0Type <: StaticArray
1270-
u0_constructor = vals -> SymbolicUtils.Code.create_array(
1271-
u0Type, floatT, Val(1), Val(length(vals)), vals...)
1272-
end
1273-
if p_constructor === identity && pType <: StaticArray
1274-
p_constructor = vals -> SymbolicUtils.Code.create_array(
1275-
pType, floatT, Val(1), Val(length(vals)), vals...)
1276-
end
1312+
1313+
u0_constructor = get_u0_constructor(u0_constructor, u0Type, u0_eltype, symbolic_u0)
1314+
p_constructor = get_p_constructor(p_constructor, pType, floatT)
12771315

12781316
if build_initializeprob
12791317
kws = maybe_build_initialization_problem(

0 commit comments

Comments
 (0)