Skip to content

Commit 1f669d9

Browse files
Merge pull request #3252 from AayushSabharwal/as/initprob-resid-prototype
fix: recalculate resid_prototype in remake_initialization_data
2 parents 4626fe7 + c578da1 commit 1f669d9

File tree

3 files changed

+38
-8
lines changed

3 files changed

+38
-8
lines changed

src/systems/nonlinear/initializesystem.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,16 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p,
260260
newp = remake_buffer(
261261
oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals)
262262
end
263-
initprob = remake(oldinitprob; u0 = newu0, p = newp)
263+
if oldinitprob.f.resid_prototype === nothing
264+
newf = oldinitprob.f
265+
else
266+
newf = NonlinearFunction{
267+
SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}(
268+
oldinitprob.f;
269+
resid_prototype = calculate_resid_prototype(
270+
length(oldinitprob.f.resid_prototype), newu0, newp))
271+
end
272+
initprob = remake(oldinitprob; f = newf, u0 = newu0, p = newp)
264273
return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!,
265274
odefn.initializeprobmap, odefn.initializeprobpmap)
266275
end

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,16 @@ function hessian_sparsity(sys::NonlinearSystem)
283283
unknowns(sys)) for eq in equations(sys)]
284284
end
285285

286+
function calculate_resid_prototype(N, u0, p)
287+
u0ElType = u0 === nothing ? Float64 : eltype(u0)
288+
if SciMLStructures.isscimlstructure(p)
289+
u0ElType = promote_type(
290+
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
291+
u0ElType)
292+
end
293+
return zeros(u0ElType, N)
294+
end
295+
286296
"""
287297
```julia
288298
SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
@@ -337,13 +347,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
337347
if length(dvs) == length(equations(sys))
338348
resid_prototype = nothing
339349
else
340-
u0ElType = u0 === nothing ? Float64 : eltype(u0)
341-
if SciMLStructures.isscimlstructure(p)
342-
u0ElType = promote_type(
343-
eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]),
344-
u0ElType)
345-
end
346-
resid_prototype = zeros(u0ElType, length(equations(sys)))
350+
resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p)
347351
end
348352

349353
NonlinearFunction{iip}(f,

test/initializationsystem.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,3 +1032,20 @@ end
10321032
@test prob3.f.initialization_data !== nothing
10331033
@test init(prob3)[x] 0.5
10341034
end
1035+
1036+
@testset "Issue#3246: type promotion with parameter dependent initialization_eqs" begin
1037+
@variables x(t)=1 y(t)=1
1038+
@parameters a = 1
1039+
@named sys = ODESystem([D(x) ~ 0, D(y) ~ x + a], t; initialization_eqs = [y ~ a])
1040+
1041+
ssys = structural_simplify(sys)
1042+
prob = ODEProblem(ssys, [], (0, 1), [])
1043+
1044+
@test SciMLBase.successful_retcode(solve(prob))
1045+
1046+
seta = setsym_oop(prob, [a])
1047+
(newu0, newp) = seta(prob, ForwardDiff.Dual{ForwardDiff.Tag{:tag, Float64}}.([1.0], 1))
1048+
newprob = remake(prob, u0 = newu0, p = newp)
1049+
1050+
@test SciMLBase.successful_retcode(solve(newprob))
1051+
end

0 commit comments

Comments
 (0)