Skip to content

Commit c01ffa5

Browse files
Merge pull request #911 from AayushSabharwal/as/remake-eager-init
fix: fix eager initialization in `remake`
2 parents d2d5e6f + a08e4dd commit c01ffa5

File tree

4 files changed

+52
-75
lines changed

4 files changed

+52
-75
lines changed

src/problems/problem_utils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,3 +180,6 @@ function Base.summary(io::IO, prob::AbstractPDEProblem)
180180
end
181181

182182
Base.copy(p::SciMLBase.NullParameters) = p
183+
184+
SymbolicIndexingInterface.is_time_dependent(::AbstractDEProblem) = true
185+
SymbolicIndexingInterface.is_time_dependent(::AbstractNonlinearProblem) = false

src/remake.jl

Lines changed: 37 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -257,18 +257,9 @@ function remake(prob::ODEProblem; f = missing,
257257
ODEProblem{iip}(f, newu0, tspan, newp, prob.problem_type; kwargs...)
258258
end
259259

260-
if lazy_initialization === nothing
261-
lazy_initialization = !is_trivial_initialization(initialization_data)
262-
end
263-
if initialization_data !== nothing && !lazy_initialization
264-
u0, p, _ = get_initial_values(
265-
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
266-
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
267-
u0 = nothing
268-
end
269-
@reset prob.u0 = u0
270-
@reset prob.p = p
271-
end
260+
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
261+
@reset prob.u0 = u0
262+
@reset prob.p = p
272263

273264
return prob
274265
end
@@ -453,18 +444,10 @@ function remake(prob::SDEProblem;
453444
else
454445
SDEProblem{iip}(f, newu0, tspan, newp; noise, noise_rate_prototype, seed, kwargs...)
455446
end
456-
if lazy_initialization === nothing
457-
lazy_initialization = !is_trivial_initialization(initialization_data)
458-
end
459-
if initialization_data !== nothing && !lazy_initialization
460-
u0, p, _ = get_initial_values(
461-
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
462-
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
463-
u0 = nothing
464-
end
465-
@reset prob.u0 = u0
466-
@reset prob.p = p
467-
end
447+
448+
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
449+
@reset prob.u0 = u0
450+
@reset prob.p = p
468451

469452
return prob
470453
end
@@ -520,18 +503,10 @@ function remake(prob::DDEProblem; f = missing, h = missing, u0 = missing,
520503
DDEProblem{iip}(f, newu0, h, tspan, newp; constant_lags, dependent_lags,
521504
order_discontinuity_t0, neutral, kwargs...)
522505
end
523-
if lazy_initialization === nothing
524-
lazy_initialization = !is_trivial_initialization(initialization_data)
525-
end
526-
if initialization_data !== nothing && !lazy_initialization
527-
u0, p, _ = get_initial_values(
528-
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
529-
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
530-
u0 = nothing
531-
end
532-
@reset prob.u0 = u0
533-
@reset prob.p = p
534-
end
506+
507+
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
508+
@reset prob.u0 = u0
509+
@reset prob.p = p
535510

536511
return prob
537512
end
@@ -619,18 +594,9 @@ function remake(prob::SDDEProblem;
619594
dependent_lags, order_discontinuity_t0, neutral, kwargs...)
620595
end
621596

622-
if lazy_initialization === nothing
623-
lazy_initialization = !is_trivial_initialization(initialization_data)
624-
end
625-
if initialization_data !== nothing && !lazy_initialization
626-
u0, p, _ = get_initial_values(
627-
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
628-
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
629-
u0 = nothing
630-
end
631-
@reset prob.u0 = u0
632-
@reset prob.p = p
633-
end
597+
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
598+
@reset prob.u0 = u0
599+
@reset prob.p = p
634600

635601
return prob
636602
end
@@ -741,18 +707,9 @@ function remake(prob::NonlinearProblem;
741707
problem_type = problem_type; kwargs...)
742708
end
743709

744-
if lazy_initialization === nothing
745-
lazy_initialization = !is_trivial_initialization(initialization_data)
746-
end
747-
if initialization_data !== nothing && !lazy_initialization
748-
u0, p, _ = get_initial_values(
749-
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
750-
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
751-
u0 = nothing
752-
end
753-
@reset prob.u0 = u0
754-
@reset prob.p = p
755-
end
710+
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
711+
@reset prob.u0 = u0
712+
@reset prob.p = p
756713

757714
return prob
758715
end
@@ -792,18 +749,9 @@ function remake(prob::NonlinearLeastSquaresProblem; f = missing, u0 = missing, p
792749
f, u0 = newu0, p = newp, kwargs...)
793750
end
794751

795-
if lazy_initialization === nothing
796-
lazy_initialization = !is_trivial_initialization(initialization_data)
797-
end
798-
if initialization_data !== nothing && !lazy_initialization
799-
u0, p, _ = get_initial_values(
800-
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
801-
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
802-
u0 = nothing
803-
end
804-
@reset prob.u0 = u0
805-
@reset prob.p = p
806-
end
752+
u0, p = maybe_eager_initialize_problem(prob, initialization_data, lazy_initialization)
753+
@reset prob.u0 = u0
754+
@reset prob.p = p
807755

808756
return prob
809757
end
@@ -1134,6 +1082,23 @@ function process_p_u0_symbolic(prob, p, u0)
11341082
end
11351083
end
11361084

1085+
function maybe_eager_initialize_problem(prob::AbstractSciMLProblem, initialization_data, lazy_initialization::Union{Nothing, Bool})
1086+
if lazy_initialization === nothing
1087+
lazy_initialization = !is_trivial_initialization(initialization_data)
1088+
end
1089+
if initialization_data !== nothing && !lazy_initialization && (!is_time_dependent(prob) || current_time(prob) !== nothing)
1090+
u0, p, _ = get_initial_values(
1091+
prob, prob, prob.f, OverrideInit(), Val(isinplace(prob)))
1092+
if u0 !== nothing && eltype(u0) == Any && isempty(u0)
1093+
u0 = nothing
1094+
end
1095+
else
1096+
u0 = state_values(prob)
1097+
p = parameter_values(prob)
1098+
end
1099+
return u0, p
1100+
end
1101+
11371102
function remake(thing::AbstractJumpProblem; kwargs...)
11381103
parameterless_type(thing)(remake(thing.prob; kwargs...))
11391104
end

test/initialization.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,10 @@ end
263263
@testset "Trivial initialization" begin
264264
initprob = NonlinearProblem(Returns(nothing), nothing, [1.0])
265265
update_initializeprob! = function (iprob, integ)
266-
iprob.p[1] = integ.u[1]
266+
# just to access the current time and use it as a number, so this errors
267+
# if run on a problem with `current_time(prob) === nothing`
268+
iprob.p[1] = current_time(integ) + 1
269+
iprob.p[1] = state_values(integ)[1]
267270
end
268271
initprobmap = function (nlsol)
269272
u1 = parameter_values(nlsol)[1]
@@ -284,6 +287,11 @@ end
284287
@test u0 [2.0, 2.0]
285288
@test p 0.0
286289
@test success
290+
291+
@testset "Doesn't run in `remake` if `tspan == (nothing, nothing)`" begin
292+
prob = ODEProblem(fn, [2.0, 0.0], (nothing, nothing), 0.0)
293+
@test_nowarn remake(prob)
294+
end
287295
end
288296
end
289297

test/remake_tests.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ u0 = [1.0; 2.0; 3.0]
1515
tspan = (0.0, 100.0)
1616
p = [10.0, 20.0, 30.0]
1717
sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
18+
indep_sys = SymbolCache([:x, :y, :z], [:a, :b, :c])
1819
fn = ODEFunction(lorenz!; sys)
1920
for T in containerTypes
2021
push!(probs, ODEProblem(fn, u0, tspan, T(p)))
@@ -64,7 +65,7 @@ function loss(x, p)
6465
return sum(du)
6566
end
6667

67-
fn = OptimizationFunction(loss; sys)
68+
fn = OptimizationFunction(loss; sys = indep_sys)
6869
for T in containerTypes
6970
push!(probs, OptimizationProblem(fn, u0, T(p)))
7071
end
@@ -73,7 +74,7 @@ function nllorenz!(du, u, p)
7374
lorenz!(du, u, p, 0.0)
7475
end
7576

76-
fn = NonlinearFunction(nllorenz!; sys)
77+
fn = NonlinearFunction(nllorenz!; sys = indep_sys)
7778
for T in containerTypes
7879
push!(probs, NonlinearProblem(fn, u0, T(p)))
7980
end

0 commit comments

Comments
 (0)