Skip to content

Commit 6023efb

Browse files
refactor: add float_type_from_varmap
1 parent a8e2495 commit 6023efb

File tree

1 file changed

+22
-12
lines changed

1 file changed

+22
-12
lines changed

src/systems/problem_utils.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,26 @@ function maybe_build_initialization_problem(
713713
initializeprobpmap))
714714
end
715715

716+
"""
717+
$(TYPEDSIGNATURES)
718+
719+
Calculate the floating point type to use from the given `varmap` by looking at variables
720+
with a constant value.
721+
"""
722+
function float_type_from_varmap(varmap, floatT = Bool)
723+
for (k, v) in varmap
724+
symbolic_type(v) == NotSymbolic() || continue
725+
is_array_of_symbolics(v) && continue
726+
727+
if v isa AbstractArray
728+
floatT = promote_type(floatT, eltype(v))
729+
elseif v isa Real
730+
floatT = promote_type(floatT, typeof(v))
731+
end
732+
end
733+
return float(floatT)
734+
end
735+
716736
"""
717737
$(TYPEDSIGNATURES)
718738
@@ -818,20 +838,10 @@ function process_SciMLProblem(
818838

819839
floatT = Bool
820840
if u0Type <: AbstractArray && eltype(u0Type) <: Real
821-
floatT = eltype(u0Type)
841+
floatT = float(eltype(u0Type))
822842
else
823-
for (k, v) in op
824-
symbolic_type(v) == NotSymbolic() || continue
825-
is_array_of_symbolics(v) && continue
826-
827-
if v isa AbstractArray
828-
floatT = promote_type(floatT, eltype(v))
829-
elseif v isa Real
830-
floatT = promote_type(floatT, typeof(v))
831-
end
832-
end
843+
floatT = float_type_from_varmap(op, floatT)
833844
end
834-
floatT = float(floatT)
835845

836846
if !is_time_dependent(sys) || is_initializesystem(sys)
837847
add_observed_equations!(u0map, obs)

0 commit comments

Comments
 (0)