Skip to content

Commit 8836abe

Browse files
fix: improve type promotion in remake_initializeprob
1 parent 96faa98 commit 8836abe

File tree

2 files changed

+77
-4
lines changed

2 files changed

+77
-4
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,15 +173,54 @@ function is_parameter_solvable(p, pmap, defs, guesses)
173173
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
174174
# the ODEProblem and it has a default and a guess)
175175
return ((_val1 === missing || _val2 === missing) ||
176-
(_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
176+
(_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
177177
end
178178

179179
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
180-
if (u0 === missing || !(eltype(u0) <: Pair) || isempty(u0)) &&
181-
(p === missing || !(eltype(p) <: Pair) || isempty(p))
180+
if u0 === missing && p === missing
182181
return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap,
183182
odefn.initializeprobpmap
184183
end
184+
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
185+
oldinitprob = odefn.initializeprob
186+
if oldinitprob === nothing || !SciMLBase.has_sys(oldinitprob.f) ||
187+
!(oldinitprob.f.sys isa NonlinearSystem)
188+
return oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap,
189+
odefn.initializeprobpmap
190+
end
191+
pidxs = ParameterIndex[]
192+
pvals = []
193+
u0idxs = Int[]
194+
u0vals = []
195+
for sym in variable_symbols(oldinitprob)
196+
if is_variable(sys, sym)
197+
u0 !== missing || continue
198+
idx = variable_index(oldinitprob, sym)
199+
push!(u0idxs, idx)
200+
push!(u0vals, eltype(u0)(state_values(oldinitprob, idx)))
201+
else
202+
p !== missing || continue
203+
idx = variable_index(oldinitprob, sym)
204+
push!(u0idxs, idx)
205+
push!(u0vals, typeof(getp(sys, sym)(p))(state_values(oldinitprob, idx)))
206+
end
207+
end
208+
if p !== missing
209+
for sym in parameter_symbols(oldinitprob)
210+
push!(pidxs, parameter_index(oldinitprob, sym))
211+
if isequal(sym, get_iv(sys))
212+
push!(pvals, t0)
213+
else
214+
push!(pvals, getp(sys, sym)(p))
215+
end
216+
end
217+
end
218+
newu0 = remake_buffer(oldinitprob.f.sys, state_values(oldinitprob), u0idxs, u0vals)
219+
newp = remake_buffer(oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
220+
initprob = remake(oldinitprob; u0 = newu0, p = newp)
221+
return initprob, odefn.update_initializeprob!, odefn.initializeprobmap,
222+
odefn.initializeprobpmap
223+
end
185224
if u0 === missing || isempty(u0)
186225
u0 = Dict()
187226
elseif !(eltype(u0) <: Pair)

test/initializationsystem.jl

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test
2-
using SymbolicIndexingInterface
2+
using ForwardDiff
3+
using SymbolicIndexingInterface, SciMLStructures
4+
using SciMLStructures: Tunable
35
using ModelingToolkit: t_nounits as t, D_nounits as D
46

57
@parameters g
@@ -746,3 +748,35 @@ end
746748
prob5 = remake(prob)
747749
@test init(prob, Tsit5()).ps[p] 2.0
748750
end
751+
752+
@testset "`remake` changes initialization problem types" begin
753+
@variables x(t) y(t) z(t)
754+
@parameters p q
755+
@mtkbuild sys = ODESystem(
756+
[D(x) ~ x * p + y * q, y^2 * q + q^2 * x ~ 0, z * p - p^2 * x * z ~ 0],
757+
t; guesses = [x => 0.0, y => 0.0, z => 0.0, p => 0.0, q => 0.0])
758+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0, q => missing])
759+
@test is_variable(prob.f.initializeprob, q)
760+
ps = prob.p
761+
newps = SciMLStructures.replace(Tunable(), ps, ForwardDiff.Dual.(ps.tunable))
762+
prob2 = remake(prob; p = newps)
763+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
764+
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
765+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
766+
767+
prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0))
768+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
769+
@test eltype(prob2.f.initializeprob.p.tunable) <: Float64
770+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
771+
772+
prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0), p = newps)
773+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
774+
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
775+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
776+
777+
prob2 = remake(prob; u0 = [x => ForwardDiff.Dual(1.0)],
778+
p = [p => ForwardDiff.Dual(1.0), q => missing])
779+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
780+
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
781+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
782+
end

0 commit comments

Comments
 (0)