Skip to content

Commit f9cdd16

Browse files
fix: respect floatT in maybe_build_initialization_problem
1 parent 662217b commit f9cdd16

File tree

1 file changed

+29
-9
lines changed

1 file changed

+29
-9
lines changed

src/systems/problem_utils.jl

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,22 @@ function maybe_build_initialization_problem(
641641

642642
initializeprob = ModelingToolkit.InitializationProblem{true, SciMLBase.FullSpecialize}(
643643
sys, t, u0map, pmap; guesses, kwargs...)
644+
if state_values(initializeprob) !== nothing
645+
initializeprob = remake(initializeprob; u0 = floatT.(state_values(initializeprob)))
646+
end
647+
initp = parameter_values(initializeprob)
648+
if is_split(sys)
649+
buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), initp)
650+
initp = repack(floatT.(buffer))
651+
buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp)
652+
initp = repack(floatT.(buffer))
653+
elseif initp isa AbstractArray
654+
initp′ = similar(initp, floatT)
655+
copyto!(initp′, initp)
656+
initp = initp′
657+
end
658+
initializeprob = remake(initializeprob; p = initp)
659+
644660
meta = get_metadata(initializeprob.f.sys)
645661

646662
if is_time_dependent(sys)
@@ -800,15 +816,19 @@ function process_SciMLProblem(
800816
u0map, pmap, defs, cmap, dvs, ps)
801817

802818
floatT = Bool
803-
for (k, v) in op
804-
symbolic_type(v) == NotSymbolic() || continue
805-
is_array_of_symbolics(v) && continue
806-
807-
if v isa AbstractArray
808-
isconcretetype(eltype(v)) || continue
809-
floatT = promote_type(floatT, eltype(v))
810-
elseif v isa Real && isconcretetype(v)
811-
floatT = promote_type(floatT, typeof(v))
819+
if u0Type <: AbstractArray && isconcretetype(eltype(u0Type)) && eltype(u0Type) <: Real
820+
floatT = eltype(u0Type)
821+
else
822+
for (k, v) in op
823+
symbolic_type(v) == NotSymbolic() || continue
824+
is_array_of_symbolics(v) && continue
825+
826+
if v isa AbstractArray
827+
isconcretetype(eltype(v)) || continue
828+
floatT = promote_type(floatT, eltype(v))
829+
elseif v isa Real && isconcretetype(v)
830+
floatT = promote_type(floatT, typeof(v))
831+
end
812832
end
813833
end
814834
floatT = float(floatT)

0 commit comments

Comments
 (0)