Skip to content

Commit e5245e1

Browse files
feat: run trivial initialization in problem constructor
1 parent 4eb88e8 commit e5245e1

File tree

4 files changed

+37
-10
lines changed

4 files changed

+37
-10
lines changed

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
759759
kwargs1 = merge(kwargs1, (; tstops))
760760
end
761761

762-
return ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
762+
return remake(ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...))
763763
end
764764
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
765765

@@ -963,8 +963,9 @@ function DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan
963963
kwargs1 = merge(kwargs1, (; tstops))
964964
end
965965

966-
DAEProblem{iip}(f, du0, u0, tspan, p; differential_vars = differential_vars,
967-
kwargs..., kwargs1...)
966+
return remake(DAEProblem{iip}(
967+
f, du0, u0, tspan, p; differential_vars = differential_vars,
968+
kwargs..., kwargs1...))
968969
end
969970

970971
function generate_history(sys::AbstractODESystem, u0; expression = Val{false}, kwargs...)
@@ -1008,7 +1009,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10081009
if cbs !== nothing
10091010
kwargs1 = merge(kwargs1, (callback = cbs,))
10101011
end
1011-
DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...)
1012+
return remake(DDEProblem{iip}(f, u0, h, tspan, p; kwargs1..., kwargs...))
10121013
end
10131014

10141015
function DiffEqBase.SDDEProblem(sys::AbstractODESystem, args...; kwargs...)
@@ -1057,9 +1058,9 @@ function DiffEqBase.SDDEProblem{iip}(sys::AbstractODESystem, u0map = [],
10571058
else
10581059
noise_rate_prototype = zeros(eltype(u0), size(noiseeqs))
10591060
end
1060-
SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
1061+
return remake(SDDEProblem{iip}(f, f.g, u0, h, tspan, p;
10611062
noise_rate_prototype =
1062-
noise_rate_prototype, kwargs1..., kwargs...)
1063+
noise_rate_prototype, kwargs1..., kwargs...))
10631064
end
10641065

10651066
"""

src/systems/diffeqs/sdesystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -792,8 +792,8 @@ function DiffEqBase.SDEProblem{iip, specialize}(
792792

793793
kwargs = filter_kwargs(kwargs)
794794

795-
SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
796-
noise_rate_prototype = noise_rate_prototype, kwargs...)
795+
return remake(SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise,
796+
noise_rate_prototype = noise_rate_prototype, kwargs...))
797797
end
798798

799799
function DiffEqBase.SDEProblem(sys::ODESystem, args...; kwargs...)

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
519519
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
520520
check_length, kwargs...)
521521
pt = something(get_metadata(sys), StandardNonlinearProblem())
522-
NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)
522+
return remake(NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...))
523523
end
524524

525525
"""
@@ -548,7 +548,7 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
548548
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
549549
check_length, kwargs...)
550550
pt = something(get_metadata(sys), StandardNonlinearProblem())
551-
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
551+
return remake(NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...))
552552
end
553553

554554
const TypeT = Union{DataType, UnionAll}

test/initializationsystem.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,3 +1388,29 @@ end
13881388
integ1 = init(oprob1)
13891389
@test integ1[X1] 1.0
13901390
end
1391+
1392+
@testset "Trivial initialization is run on problem construction" begin
1393+
@variables _x(..) y(t)
1394+
@brownian a
1395+
@parameters tot
1396+
x = _x(t)
1397+
@testset "$Problem" for (Problem, lhs, rhs) in [
1398+
(ODEProblem, D, 0.0),
1399+
(SDEProblem, D, a),
1400+
(DDEProblem, D, _x(t - 0.1)),
1401+
(SDDEProblem, D, _x(t - 0.1) + a)
1402+
]
1403+
@mtkbuild sys = ModelingToolkit.System([lhs(x) ~ x + rhs, x + y ~ tot], t;
1404+
guesses = [tot => 1.0], defaults = [tot => missing])
1405+
prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0))
1406+
@test prob.ps[tot] 2.0
1407+
end
1408+
@testset "$Problem" for Problem in [NonlinearProblem, NonlinearLeastSquaresProblem]
1409+
@parameters p1 p2
1410+
@mtkbuild sys = NonlinearSystem([x^2 + y^2 ~ p1, (x - 1)^2 + (y - 1)^2 ~ p2];
1411+
parameter_dependencies = [p2 ~ 2p1],
1412+
guesses = [p1 => 0.0], defaults = [p1 => missing])
1413+
prob = Problem(sys, [x => 1.0, y => 1.0], [p2 => 6.0])
1414+
@test prob.ps[p1] 3.0
1415+
end
1416+
end

0 commit comments

Comments
 (0)