Skip to content

Commit 70f901f

Browse files
Merge pull request #3060 from AayushSabharwal/as/odeprob-nothing
fix: handle `nothing` passed as `u0` to `ODEProblem`
2 parents 90e6398 + f84e571 commit 70f901f

File tree

7 files changed

+47
-2
lines changed

7 files changed

+47
-2
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,11 +817,12 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
817817
ModelingToolkit.get_tearing_state(sys) !== nothing) ||
818818
!isempty(initialization_equations(sys))) && t !== nothing
819819
if eltype(u0map) <: Number
820-
u0map = unknowns(sys) .=> u0map
820+
u0map = unknowns(sys) .=> vec(u0map)
821821
end
822-
if isempty(u0map)
822+
if u0map === nothing || isempty(u0map)
823823
u0map = Dict()
824824
end
825+
825826
initializeprob = ModelingToolkit.InitializationProblem(
826827
sys, t, u0map, parammap; guesses, warn_initialize_determined,
827828
initialization_eqs, eval_expression, eval_module, fully_determined, check_units)

src/systems/discrete_system/discrete_system.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,13 @@ function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, paramm
247247
dvs = unknowns(sys)
248248
ps = parameters(sys)
249249

250+
if eltype(u0map) <: Number
251+
u0map = unknowns(sys) .=> vec(u0map)
252+
end
253+
if u0map === nothing || isempty(u0map)
254+
u0map = Dict()
255+
end
256+
250257
trueu0map = Dict()
251258
for (k, v) in u0map
252259
k = unwrap(k)

test/discrete_system.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,3 +271,11 @@ end
271271
k = ShiftIndex(t)
272272
@named sys = DiscreteSystem([x ~ x^2 + y^2, y ~ x(k - 1) + y(k - 1)], t)
273273
@test_throws ["algebraic equations", "not yet supported"] structural_simplify(sys)
274+
275+
@testset "Passing `nothing` to `u0`" begin
276+
@variables x(t) = 1
277+
k = ShiftIndex()
278+
@mtkbuild sys = DiscreteSystem([x(k) ~ x(k - 1) + 1], t)
279+
prob = @test_nowarn DiscreteProblem(sys, nothing, (0.0, 1.0))
280+
@test_nowarn solve(prob, FunctionMap())
281+
end

test/nonlinearsystem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,3 +318,10 @@ sys = structural_simplify(ns; conservative = true)
318318
sol = solve(prob, NewtonRaphson())
319319
@test sol[x] sol[y] sol[z] -3
320320
end
321+
322+
@testset "Passing `nothing` to `u0`" begin
323+
@variables x = 1
324+
@mtkbuild sys = NonlinearSystem([0 ~ x^2 - x^3 + 3])
325+
prob = @test_nowarn NonlinearProblem(sys, nothing)
326+
@test_nowarn solve(prob)
327+
end

test/odesystem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,3 +1387,10 @@ end
13871387

13881388
@test obsfn(ones(2), 2ones(2), 3ones(4), 4.0) == 6ones(2)
13891389
end
1390+
1391+
@testset "Passing `nothing` to `u0`" begin
1392+
@variables x(t) = 1
1393+
@mtkbuild sys = ODESystem(D(x) ~ t, t)
1394+
prob = @test_nowarn ODEProblem(sys, nothing, (0.0, 1.0))
1395+
@test_nowarn solve(prob)
1396+
end

test/optimizationsystem.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,10 @@ end
340340
prob.f.cons_h(H3, [1.0, 1.0], [1.0, 100.0])
341341
@test prob.f.cons_h([1.0, 1.0], [1.0, 100.0]) == H3
342342
end
343+
344+
@testset "Passing `nothing` to `u0`" begin
345+
@variables x = 1.0
346+
@mtkbuild sys = OptimizationSystem((x - 3)^2, [x], [])
347+
prob = @test_nowarn OptimizationProblem(sys, nothing)
348+
@test_nowarn solve(prob, NelderMead())
349+
end

test/sdesystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,3 +776,11 @@ end
776776
prob = SDEProblem(de, u0map, (0.0, 100.0), parammap)
777777
@test solve(prob, SOSRI()).retcode == ReturnCode.Success
778778
end
779+
780+
@testset "Passing `nothing` to `u0`" begin
781+
@variables x(t) = 1
782+
@brownian b
783+
@mtkbuild sys = System([D(x) ~ x + b], t)
784+
prob = @test_nowarn SDEProblem(sys, nothing, (0.0, 1.0))
785+
@test_nowarn solve(prob, ImplicitEM())
786+
end

0 commit comments

Comments
 (0)