Skip to content

Commit a8ed1c8

Browse files
Merge pull request #3648 from AayushSabharwal/as/problem-refactor-2
refactor: change problem constructors to `XProblem(sys, op[, tspan])`
2 parents a049d76 + 4d79ef4 commit a8ed1c8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+586
-573
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,10 @@ function MTK.CasADiDynamicOptProblem(sys::System, u0map, tspan, pmap;
7676
steps = nothing,
7777
guesses = Dict(), kwargs...)
7878
MTK.warn_overdetermined(sys, u0map)
79-
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
80-
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
79+
_u0map = has_alg_eqs(sys) ? MTK.to_varmap(u0map, unknowns(sys)) :
80+
merge(Dict(u0map), Dict(guesses))
81+
pmap = MTK.to_varmap(pmap, parameters(sys))
82+
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, merge(_u0map, pmap);
8183
t = tspan !== nothing ? tspan[1] : tspan, output_type = MX, kwargs...)
8284

8385
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))

ext/MTKInfiniteOptExt.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,10 @@ function MTK.JuMPDynamicOptProblem(sys::System, u0map, tspan, pmap;
5959
steps = nothing,
6060
guesses = Dict(), kwargs...)
6161
MTK.warn_overdetermined(sys, u0map)
62-
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
63-
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
62+
_u0map = has_alg_eqs(sys) ? MTK.to_varmap(u0map, unknowns(sys)) :
63+
merge(Dict(u0map), Dict(guesses))
64+
pmap = MTK.to_varmap(pmap, parameters(sys))
65+
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, merge(_u0map, pmap);
6466
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
6567

6668
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))
@@ -86,8 +88,10 @@ function MTK.InfiniteOptDynamicOptProblem(sys::System, u0map, tspan, pmap;
8688
steps = nothing,
8789
guesses = Dict(), kwargs...)
8890
MTK.warn_overdetermined(sys, u0map)
89-
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
90-
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, _u0map, pmap;
91+
_u0map = has_alg_eqs(sys) ? MTK.to_varmap(u0map, unknowns(sys)) :
92+
merge(Dict(u0map), Dict(guesses))
93+
pmap = MTK.to_varmap(pmap, parameters(sys))
94+
f, u0, p = MTK.process_SciMLProblem(ODEInputFunction, sys, merge(_u0map, pmap);
9195
t = tspan !== nothing ? tspan[1] : tspan, kwargs...)
9296

9397
pmap = MTK.recursive_unwrap(MTK.AnyDict(pmap))

src/ModelingToolkit.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ PrecompileTools.@compile_workload begin
234234
using ModelingToolkit
235235
@variables x(ModelingToolkit.t_nounits)
236236
@named sys = System([ModelingToolkit.D_nounits(x) ~ -x], ModelingToolkit.t_nounits)
237-
prob = ODEProblem(mtkcompile(sys), [x => 30.0], (0, 100), [], jac = true)
237+
prob = ODEProblem(mtkcompile(sys), [x => 30.0], (0, 100), jac = true)
238238
@mtkmodel __testmod__ begin
239239
@constants begin
240240
c = 1.0

src/deprecations.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,3 @@ macro mtkbuild(exprs...)
88
@mtkcompile $(exprs...)
99
end |> esc
1010
end
11-

src/linearization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ function linearization_function(sys::AbstractSystem, inputs,
7373
end
7474

7575
prob = ODEProblem{true, SciMLBase.FullSpecialize}(
76-
sys, op, (nothing, nothing), p; allow_incomplete = true,
76+
sys, merge(op, anydict(p)), (nothing, nothing); allow_incomplete = true,
7777
algebraic_only = true, guesses)
7878
u0 = state_values(prob)
7979

src/problems/bvproblem.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ If the `System` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
4444
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
4545
"""
4646
@fallback_iip_specialize function SciMLBase.BVProblem{iip, spec}(
47-
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
47+
sys::System, op, tspan;
4848
check_compatibility = true, cse = true,
4949
checkbounds = false, eval_expression = false, eval_module = @__MODULE__,
5050
expression = Val{false}, guesses = Dict(), callback = nothing,
@@ -55,22 +55,23 @@ If the `System` has algebraic equations, like `x(t)^2 + y(t)^2`, the resulting
5555

5656
# Systems without algebraic equations should use both fixed values + guesses
5757
# for initialization.
58-
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
58+
_op = has_alg_eqs(sys) ? op : merge(Dict(op), Dict(guesses))
5959

6060
fode, u0, p = process_SciMLProblem(
61-
ODEFunction{iip, spec}, sys, _u0map, parammap; guesses,
61+
ODEFunction{iip, spec}, sys, _op; guesses,
6262
t = tspan !== nothing ? tspan[1] : tspan, check_compatibility = false, cse,
6363
checkbounds, time_dependent_init = false, expression, kwargs...)
6464

6565
dvs = unknowns(sys)
6666
stidxmap = Dict([v => i for (i, v) in enumerate(dvs)])
67-
u0_idxs = has_alg_eqs(sys) ? collect(1:length(dvs)) : [stidxmap[k] for (k, v) in u0map]
67+
u0_idxs = has_alg_eqs(sys) ? collect(1:length(dvs)) :
68+
[stidxmap[k] for (k, v) in op if haskey(stidxmap, k)]
6869
fbc = generate_boundary_conditions(
6970
sys, u0, u0_idxs, tspan[1]; expression = Val{false},
7071
wrap_gfw = Val{true}, cse, checkbounds)
7172

72-
if (length(constraints(sys)) + length(u0map) > length(dvs))
73-
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
73+
if (length(constraints(sys)) + length(op) > length(dvs))
74+
@warn "The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by op) exceeds the total number of states. The BVP solvers will default to doing a nonlinear least-squares optimization."
7475
end
7576

7677
kwargs = process_kwargs(sys; expression, kwargs...)

src/problems/daeproblem.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,15 @@
6262
end
6363

6464
@fallback_iip_specialize function SciMLBase.DAEProblem{iip, spec}(
65-
sys::System, du0map, u0map, tspan, parammap = SciMLBase.NullParameters();
65+
sys::System, op, tspan;
6666
callback = nothing, check_length = true, eval_expression = false,
6767
eval_module = @__MODULE__, check_compatibility = true,
6868
expression = Val{false}, kwargs...) where {iip, spec}
6969
check_complete(sys, DAEProblem)
7070
check_compatibility && check_compatible_system(DAEProblem, sys)
7171

72-
f, du0, u0, p = process_SciMLProblem(DAEFunction{iip, spec}, sys, u0map, parammap;
73-
du0map, t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
72+
f, du0, u0, p = process_SciMLProblem(DAEFunction{iip, spec}, sys, op;
73+
t = tspan !== nothing ? tspan[1] : tspan, check_length, eval_expression,
7474
eval_module, check_compatibility, implicit_dae = true, expression, kwargs...)
7575

7676
kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,

src/problems/ddeproblem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,14 @@
4242
end
4343

4444
@fallback_iip_specialize function SciMLBase.DDEProblem{iip, spec}(
45-
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
45+
sys::System, op, tspan;
4646
callback = nothing, check_length = true, cse = true, checkbounds = false,
4747
eval_expression = false, eval_module = @__MODULE__, check_compatibility = true,
4848
u0_constructor = identity, expression = Val{false}, kwargs...) where {iip, spec}
4949
check_complete(sys, DDEProblem)
5050
check_compatibility && check_compatible_system(DDEProblem, sys)
5151

52-
f, u0, p = process_SciMLProblem(DDEFunction{iip, spec}, sys, u0map, parammap;
52+
f, u0, p = process_SciMLProblem(DDEFunction{iip, spec}, sys, op;
5353
t = tspan !== nothing ? tspan[1] : tspan, check_length, cse, checkbounds,
5454
eval_expression, eval_module, check_compatibility, symbolic_u0 = true,
5555
expression, u0_constructor, kwargs...)

src/problems/discreteproblem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@
3939
end
4040

4141
@fallback_iip_specialize function SciMLBase.DiscreteProblem{iip, spec}(
42-
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
42+
sys::System, op, tspan;
4343
check_compatibility = true, expression = Val{false}, kwargs...) where {iip, spec}
4444
check_complete(sys, DiscreteProblem)
4545
check_compatibility && check_compatible_system(DiscreteProblem, sys)
4646

4747
dvs = unknowns(sys)
4848
u0map = to_varmap(u0map, dvs)
4949
add_toterms!(u0map; replace = true)
50-
f, u0, p = process_SciMLProblem(DiscreteFunction{iip, spec}, sys, u0map, parammap;
50+
f, u0, p = process_SciMLProblem(DiscreteFunction{iip, spec}, sys, op;
5151
t = tspan !== nothing ? tspan[1] : tspan, check_compatibility, expression,
5252
kwargs...)
5353

src/problems/implicitdiscreteproblem.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,16 @@
4343
end
4444

4545
@fallback_iip_specialize function SciMLBase.ImplicitDiscreteProblem{iip, spec}(
46-
sys::System, u0map, tspan, parammap = SciMLBase.NullParameters();
46+
sys::System, op, tspan;
4747
check_compatibility = true, expression = Val{false}, kwargs...) where {iip, spec}
4848
check_complete(sys, ImplicitDiscreteProblem)
4949
check_compatibility && check_compatible_system(ImplicitDiscreteProblem, sys)
5050

5151
dvs = unknowns(sys)
52-
u0map = to_varmap(u0map, dvs)
53-
add_toterms!(u0map; replace = true)
52+
op = to_varmap(op, dvs)
53+
add_toterms!(op; replace = true)
5454
f, u0, p = process_SciMLProblem(
55-
ImplicitDiscreteFunction{iip, spec}, sys, u0map, parammap;
55+
ImplicitDiscreteFunction{iip, spec}, sys, op;
5656
t = tspan !== nothing ? tspan[1] : tspan, check_compatibility,
5757
expression, kwargs...)
5858

0 commit comments

Comments
 (0)