Skip to content

Commit 204987b

Browse files
fix: call remake_initialization_data when explicit f provided to remake
1 parent 957dd58 commit 204987b

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

src/remake.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,8 @@ function remake(prob::ODEProblem; f = missing,
225225

226226
if build_initializeprob
227227
if f !== missing && has_initialization_data(f)
228-
initialization_data = f.initialization_data
228+
initialization_data = remake_initialization_data(
229+
prob.f.sys, f, u0, tspan[1], p, newu0, newp)
229230
else
230231
initialization_data = remake_initialization_data(
231232
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
@@ -413,7 +414,8 @@ function remake(prob::SDEProblem;
413414

414415
if build_initializeprob
415416
if f !== missing && has_initialization_data(f)
416-
initialization_data = f.initialization_data
417+
initialization_data = remake_initialization_data(
418+
prob.f.sys, f, u0, tspan[1], p, newu0, newp)
417419
else
418420
initialization_data = remake_initialization_data(
419421
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
@@ -481,7 +483,8 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,
481483

482484
if build_initializeprob
483485
if f !== missing && has_initialization_data(f)
484-
initialization_data = f.initialization_data
486+
initialization_data = remake_initialization_data(
487+
prob.f.sys, f, u0, tspan[1], p, newu0, newp)
485488
else
486489
initialization_data = remake_initialization_data(
487490
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
@@ -561,7 +564,8 @@ function remake(prob::SDDEProblem;
561564

562565
if build_initializeprob
563566
if f !== missing && has_initialization_data(f)
564-
initialization_data = f.initialization_data
567+
initialization_data = remake_initialization_data(
568+
prob.f.sys, f, u0, tspan[1], p, newu0, newp)
565569
else
566570
initialization_data = remake_initialization_data(
567571
prob.f.sys, prob.f, u0, tspan[1], p, newu0, newp)
@@ -711,7 +715,8 @@ function remake(prob::NonlinearProblem;
711715

712716
if build_initializeprob
713717
if f !== missing && has_initialization_data(f)
714-
initialization_data = f.initialization_data
718+
initialization_data = remake_initialization_data(
719+
prob.f.sys, f, u0, nothing, p, newu0, newp)
715720
else
716721
initialization_data = remake_initialization_data(
717722
prob.f.sys, prob.f, u0, nothing, p, newu0, newp)
@@ -765,7 +770,8 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
765770

766771
if build_initializeprob
767772
if f !== missing && has_initialization_data(f)
768-
initialization_data = f.initialization_data
773+
initialization_data = remake_initialization_data(
774+
prob.f.sys, f, u0, nothing, p, newu0, newp)
769775
else
770776
initialization_data = remake_initialization_data(
771777
prob.f.sys, prob.f, u0, nothing, p, newu0, newp)

test/downstream/modelingtoolkit_remake.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using ModelingToolkit: t_nounits as t, D_nounits as D
44
using OrdinaryDiffEq
55
using Optimization
66
using OptimizationOptimJL
7+
using ForwardDiff
8+
using SciMLStructures
79

810
probs = []
911
syss = []
@@ -406,3 +408,18 @@ end
406408
prob = ODEProblem(sys, [:x => 1.0], (0.0, 1.0), [p => 1.0])
407409
@test_nowarn remake(prob; u0 = [:y => 1.0, :x => nothing])
408410
end
411+
412+
@testset "`initialization_data` u0 and p are promoted with explicit `f`" begin
413+
@variables x(t) [guess = 1.0] y(t) [guess = 1.0]
414+
@parameters p q
415+
@mtkbuild sys = ODESystem([D(x) ~ x, (x - p) ^ 2 + (y - q) ^ 3 ~ 0], t)
416+
prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0, q => 2.0])
417+
@test prob.f.initialization_data !== nothing
418+
buf, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), prob.p)
419+
newps = repack(ForwardDiff.Dual.(buf))
420+
prob2 = @test_nowarn remake(prob; f = prob.f, u0 = ForwardDiff.Dual.(prob.u0), p = newps)
421+
@test prob2.f.initialization_data !== nothing
422+
initprob = prob2.f.initialization_data.initializeprob
423+
@test eltype(initprob.u0) <: ForwardDiff.Dual
424+
@test eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), initprob.p)[1]) <: ForwardDiff.Dual
425+
end

0 commit comments

Comments
 (0)