Skip to content

Commit 5467cac

Browse files
fixup! fix: improve type promotion in remake_initializeprob
1 parent 1e29425 commit 5467cac

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

test/initializationsystem.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test
2-
using SymbolicIndexingInterface
2+
using SymbolicIndexingInterface, SciMLStructures
3+
using SciMLStructures: Tunable
34
using ModelingToolkit: t_nounits as t, D_nounits as D
45

56
@parameters g
@@ -737,3 +738,35 @@ end
737738
integ = init(prob, Tsit5())
738739
@test integ.ps[p] 2
739740
end
741+
742+
@testset "`remake` changes initialization problem types" begin
743+
@variables x(t) y(t) z(t)
744+
@parameters p q
745+
@mtkbuild sys = ODESystem(
746+
[D(x) ~ x * p + y * q, y^2 * q + q^2 * x ~ 0, z * p - p^2 * x * z ~ 0],
747+
t; guesses = [x => 0.0, y => 0.0, z => 0.0, p => 0.0, q => 0.0])
748+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0, q => missing])
749+
@test is_variable(prob.f.initializeprob, q)
750+
ps = prob.p
751+
newps = SciMLStructures.replace(Tunable(), ps, ForwardDiff.Dual.(ps.tunable))
752+
prob2 = remake(prob; p = newps)
753+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
754+
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
755+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
756+
757+
prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0))
758+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
759+
@test eltype(prob2.f.initializeprob.p.tunable) <: Float64
760+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
761+
762+
prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0), p = newps)
763+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
764+
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
765+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
766+
767+
prob2 = remake(prob; u0 = [x => ForwardDiff.Dual(1.0)],
768+
p = [p => ForwardDiff.Dual(1.0), q => missing])
769+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
770+
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
771+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
772+
end

0 commit comments

Comments
 (0)