Skip to content

Commit a16ea8c

Browse files
fix: improve type promotion in remake_initializeprob
1 parent b5c4894 commit a16ea8c

File tree

2 files changed

+76
-4
lines changed

2 files changed

+76
-4
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,15 +169,54 @@ function is_parameter_solvable(p, pmap, defs, guesses)
169169
# either (missing is a default or was passed to the ODEProblem) or (nothing was passed to
170170
# the ODEProblem and it has a default and a guess)
171171
return ((_val1 === missing || _val2 === missing) ||
172-
(_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
172+
(_val1 === nothing && _val2 !== nothing)) && _val3 !== nothing
173173
end
174174

175175
function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p)
176-
if (u0 === missing || !(eltype(u0) <: Pair) || isempty(u0)) &&
177-
(p === missing || !(eltype(p) <: Pair) || isempty(p))
176+
if u0 === missing && p === missing
178177
return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap,
179178
odefn.initializeprobpmap
180179
end
180+
if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair)
181+
oldinitprob = odefn.initializeprob
182+
if oldinitprob === nothing || !SciMLBase.has_sys(oldinitprob.f) ||
183+
!(oldinitprob.f.sys isa NonlinearSystem)
184+
return oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap,
185+
odefn.initializeprobpmap
186+
end
187+
pidxs = ParameterIndex[]
188+
pvals = []
189+
u0idxs = Int[]
190+
u0vals = []
191+
for sym in variable_symbols(oldinitprob)
192+
if is_variable(sys, sym)
193+
u0 !== missing || continue
194+
idx = variable_index(oldinitprob, sym)
195+
push!(u0idxs, idx)
196+
push!(u0vals, eltype(u0)(state_values(oldinitprob, idx)))
197+
else
198+
p !== missing || continue
199+
idx = variable_index(oldinitprob, sym)
200+
push!(u0idxs, idx)
201+
push!(u0vals, typeof(getp(sys, sym)(p))(state_values(oldinitprob, idx)))
202+
end
203+
end
204+
if p !== missing
205+
for sym in parameter_symbols(oldinitprob)
206+
push!(pidxs, parameter_index(oldinitprob, sym))
207+
if isequal(sym, get_iv(sys))
208+
push!(pvals, t0)
209+
else
210+
push!(pvals, getp(sys, sym)(p))
211+
end
212+
end
213+
end
214+
newu0 = remake_buffer(oldinitprob.f.sys, state_values(oldinitprob), u0idxs, u0vals)
215+
newp = remake_buffer(oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
216+
initprob = remake(oldinitprob; u0 = newu0, p = newp)
217+
return initprob, odefn.update_initializeprob!, odefn.initializeprobmap,
218+
odefn.initializeprobpmap
219+
end
181220
if u0 === missing || isempty(u0)
182221
u0 = Dict()
183222
elseif !(eltype(u0) <: Pair)

test/initializationsystem.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test
2-
using SymbolicIndexingInterface
2+
using SymbolicIndexingInterface, SciMLStructures
3+
using SciMLStructures: Tunable
34
using ModelingToolkit: t_nounits as t, D_nounits as D
45

56
@parameters g
@@ -746,3 +747,35 @@ end
746747
prob5 = remake(prob)
747748
@test init(prob, Tsit5()).ps[p] 2.0
748749
end
750+
751+
@testset "`remake` changes initialization problem types" begin
752+
@variables x(t) y(t) z(t)
753+
@parameters p q
754+
@mtkbuild sys = ODESystem(
755+
[D(x) ~ x * p + y * q, y^2 * q + q^2 * x ~ 0, z * p - p^2 * x * z ~ 0],
756+
t; guesses = [x => 0.0, y => 0.0, z => 0.0, p => 0.0, q => 0.0])
757+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0, q => missing])
758+
@test is_variable(prob.f.initializeprob, q)
759+
ps = prob.p
760+
newps = SciMLStructures.replace(Tunable(), ps, ForwardDiff.Dual.(ps.tunable))
761+
prob2 = remake(prob; p = newps)
762+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
763+
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
764+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
765+
766+
prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0))
767+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
768+
@test eltype(prob2.f.initializeprob.p.tunable) <: Float64
769+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
770+
771+
prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0), p = newps)
772+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
773+
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
774+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
775+
776+
prob2 = remake(prob; u0 = [x => ForwardDiff.Dual(1.0)],
777+
p = [p => ForwardDiff.Dual(1.0), q => missing])
778+
@test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual
779+
@test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual
780+
@test prob2.f.initializeprob.u0 prob.f.initializeprob.u0
781+
end

0 commit comments

Comments
 (0)