Skip to content

Commit d838bbf

Browse files
Merge pull request #3121 from AayushSabharwal/as/cleanup
refactor: major cleanup of `*Problem` construction
2 parents 1f53f6a + 1861727 commit d838bbf

15 files changed

+686
-430
lines changed

src/ModelingToolkit.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ include("systems/abstractsystem.jl")
144144
include("systems/model_parsing.jl")
145145
include("systems/connectors.jl")
146146
include("systems/callbacks.jl")
147+
include("systems/problem_utils.jl")
147148

148149
include("systems/nonlinear/nonlinearsystem.jl")
149150
include("systems/diffeqs/odesystem.jl")

src/systems/abstractsystem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2914,7 +2914,7 @@ function Base.eltype(::Type{<:TreeIterator{ModelingToolkit.AbstractSystem}})
29142914
end
29152915

29162916
function check_array_equations_unknowns(eqs, dvs)
2917-
if any(eq -> Symbolics.isarraysymbolic(eq.lhs), eqs)
2917+
if any(eq -> eq isa Equation && Symbolics.isarraysymbolic(eq.lhs), eqs)
29182918
throw(ArgumentError("The system has array equations. Call `structural_simplify` to handle such equations or scalarize them manually."))
29192919
end
29202920
if any(x -> Symbolics.isarraysymbolic(x), dvs)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 9 additions & 304 deletions
Large diffs are not rendered by default.

src/systems/diffeqs/sdesystem.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,7 @@ function DiffEqBase.SDEProblem{iip, specialize}(
659659
if !iscomplete(sys)
660660
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblem`")
661661
end
662-
f, u0, p = process_DEProblem(
662+
f, u0, p = process_SciMLProblem(
663663
SDEFunction{iip, specialize}, sys, u0map, parammap; check_length,
664664
kwargs...)
665665
cbs = process_events(sys; callback, kwargs...)
@@ -745,7 +745,8 @@ function SDEProblemExpr{iip}(sys::SDESystem, u0map, tspan,
745745
if !iscomplete(sys)
746746
error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEProblemExpr`")
747747
end
748-
f, u0, p = process_DEProblem(SDEFunctionExpr{iip}, sys, u0map, parammap; check_length,
748+
f, u0, p = process_SciMLProblem(
749+
SDEFunctionExpr{iip}, sys, u0map, parammap; check_length,
749750
kwargs...)
750751
linenumbers = get(kwargs, :linenumbers, true)
751752
sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false))

src/systems/discrete_system/discrete_system.jl

Lines changed: 16 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -233,55 +233,25 @@ function generate_function(
233233
generate_custom_function(sys, exprs, dvs, ps; wrap_code, kwargs...)
234234
end
235235

236-
function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, parammap;
237-
linenumbers = true, parallel = SerialForm(),
238-
use_union = false,
239-
tofloat = !use_union,
240-
eval_expression = false, eval_module = @__MODULE__,
241-
kwargs...)
236+
function shift_u0map_forward(sys::DiscreteSystem, u0map, defs)
242237
iv = get_iv(sys)
243-
eqs = equations(sys)
244-
dvs = unknowns(sys)
245-
ps = parameters(sys)
246-
247-
if eltype(u0map) <: Number
248-
u0map = unknowns(sys) .=> vec(u0map)
249-
end
250-
if u0map === nothing || isempty(u0map)
251-
u0map = Dict()
252-
end
253-
254-
trueu0map = Dict()
255-
for (k, v) in u0map
256-
k = unwrap(k)
238+
updated = AnyDict()
239+
for k in collect(keys(u0map))
240+
v = u0map[k]
257241
if !((op = operation(k)) isa Shift)
258242
error("Initial conditions must be for the past state of the unknowns. Instead of providing the condition for $k, provide the condition for $(Shift(iv, -1)(k)).")
259243
end
260-
trueu0map[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
261-
end
262-
defs = ModelingToolkit.get_defaults(sys)
263-
for var in dvs
264-
if (op = operation(var)) isa Shift && !haskey(trueu0map, var)
265-
root = arguments(var)[1]
266-
haskey(defs, root) || error("Initial condition for $var not provided.")
267-
trueu0map[var] = defs[root]
268-
end
244+
updated[Shift(iv, op.steps + 1)(arguments(k)[1])] = v
269245
end
270-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
271-
u0, defs = get_u0(sys, trueu0map, parammap)
272-
p = MTKParameters(sys, parammap, trueu0map)
273-
else
274-
u0, p, defs = get_u0_p(sys, trueu0map, parammap; tofloat, use_union)
246+
for var in unknowns(sys)
247+
op = operation(var)
248+
op isa Shift || continue
249+
haskey(updated, var) && continue
250+
root = first(arguments(var))
251+
haskey(defs, root) || error("Initial condition for $var not provided.")
252+
updated[var] = defs[root]
275253
end
276-
277-
check_eqs_u0(eqs, dvs, u0; kwargs...)
278-
279-
f = constructor(sys, dvs, ps, u0;
280-
linenumbers = linenumbers, parallel = parallel,
281-
syms = Symbol.(dvs), paramsyms = Symbol.(ps),
282-
eval_expression = eval_expression, eval_module = eval_module,
283-
kwargs...)
284-
return f, u0, p
254+
return updated
285255
end
286256

287257
"""
@@ -304,7 +274,9 @@ function SciMLBase.DiscreteProblem(
304274
eqs = equations(sys)
305275
iv = get_iv(sys)
306276

307-
f, u0, p = process_DiscreteProblem(
277+
u0map = to_varmap(u0map, dvs)
278+
u0map = shift_u0map_forward(sys, u0map, defaults(sys))
279+
f, u0, p = process_SciMLProblem(
308280
DiscreteFunction, sys, u0map, parammap; eval_expression, eval_module)
309281
u0 = f(u0, p, tspan[1])
310282
DiscreteProblem(f, u0, tspan, p; kwargs...)

src/systems/jumps/jumpsystem.jl

Lines changed: 6 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -348,20 +348,8 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
348348
if !iscomplete(sys)
349349
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
350350
end
351-
dvs = unknowns(sys)
352-
ps = parameters(sys)
353-
354-
defs = defaults(sys)
355-
defs = mergedefaults(defs, parammap, ps)
356-
defs = mergedefaults(defs, u0map, dvs)
357-
358-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
359-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
360-
p = MTKParameters(sys, parammap, u0map)
361-
else
362-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
363-
end
364-
351+
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
352+
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
365353
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
366354

367355
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
@@ -399,16 +387,9 @@ function DiscreteProblemExpr{iip}(sys::JumpSystem, u0map, tspan::Union{Tuple, No
399387
if !iscomplete(sys)
400388
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblemExpr`")
401389
end
402-
dvs = unknowns(sys)
403-
ps = parameters(sys)
404-
defs = defaults(sys)
405390

406-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
407-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
408-
p = MTKParameters(sys, parammap, u0map)
409-
else
410-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
411-
end
391+
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
392+
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
412393
# identity function to make syms works
413394
quote
414395
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT
@@ -454,19 +435,9 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
454435
if !iscomplete(sys)
455436
error("A completed `JumpSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `DiscreteProblem`")
456437
end
457-
dvs = unknowns(sys)
458-
ps = parameters(sys)
459-
460-
defs = defaults(sys)
461-
defs = mergedefaults(defs, parammap, ps)
462-
defs = mergedefaults(defs, u0map, dvs)
463438

464-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = false)
465-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
466-
p = MTKParameters(sys, parammap, u0map)
467-
else
468-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = false, use_union)
469-
end
439+
_, u0, p = process_SciMLProblem(EmptySciMLFunction, sys, u0map, parammap;
440+
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
470441

471442
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
472443

src/systems/nonlinear/initializesystem.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function generate_initializesystem(sys::ODESystem;
3737
# set dummy derivatives to default_dd_guess unless specified
3838
push!(defs, x[1] => get(guesses, x[1], default_dd_guess))
3939
end
40-
for (y, x) in u0map
40+
function process_u0map_with_dummysubs(y, x)
4141
y = get(schedule.dummy_sub, y, y)
4242
y = fixpoint_sub(y, diffmap)
4343
if y vars_set
@@ -53,6 +53,13 @@ function generate_initializesystem(sys::ODESystem;
5353
error("Initialization expression $y is currently not supported. If its a higher order derivative expression, then only the dummy derivative expressions are supported.")
5454
end
5555
end
56+
for (y, x) in u0map
57+
if Symbolics.isarraysymbolic(y)
58+
process_u0map_with_dummysubs.(collect(y), collect(x))
59+
else
60+
process_u0map_with_dummysubs(y, x)
61+
end
62+
end
5663
end
5764

5865
# 2) process other variables

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ function SciMLBase.NonlinearFunction(sys::NonlinearSystem, args...; kwargs...)
290290
end
291291

292292
function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
293-
ps = parameters(sys), u0 = nothing, p = nothing;
293+
ps = parameters(sys), u0 = nothing; p = nothing,
294294
version = nothing,
295295
jac = false,
296296
eval_expression = false,
@@ -405,36 +405,6 @@ function NonlinearFunctionExpr{iip}(sys::NonlinearSystem, dvs = unknowns(sys),
405405
!linenumbers ? Base.remove_linenums!(ex) : ex
406406
end
407407

408-
function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, parammap;
409-
version = nothing,
410-
jac = false,
411-
checkbounds = false, sparse = false,
412-
simplify = false,
413-
linenumbers = true, parallel = SerialForm(),
414-
eval_expression = false,
415-
eval_module = @__MODULE__,
416-
use_union = false,
417-
tofloat = !use_union,
418-
kwargs...)
419-
eqs = equations(sys)
420-
dvs = unknowns(sys)
421-
ps = parameters(sys)
422-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
423-
u0, defs = get_u0(sys, u0map, parammap)
424-
check_eqs_u0(eqs, dvs, u0; kwargs...)
425-
p = MTKParameters(sys, parammap, u0map)
426-
else
427-
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
428-
check_eqs_u0(eqs, dvs, u0; kwargs...)
429-
end
430-
431-
f = constructor(sys, dvs, ps, u0, p; jac = jac, checkbounds = checkbounds,
432-
linenumbers = linenumbers, parallel = parallel, simplify = simplify,
433-
sparse = sparse, eval_expression = eval_expression, eval_module = eval_module,
434-
kwargs...)
435-
return f, u0, p
436-
end
437-
438408
"""
439409
```julia
440410
DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
@@ -458,7 +428,7 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map,
458428
if !iscomplete(sys)
459429
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`")
460430
end
461-
f, u0, p = process_NonlinearProblem(NonlinearFunction{iip}, sys, u0map, parammap;
431+
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
462432
check_length, kwargs...)
463433
pt = something(get_metadata(sys), StandardNonlinearProblem())
464434
NonlinearProblem{iip}(f, u0, p, pt; filter_kwargs(kwargs)...)
@@ -487,7 +457,7 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
487457
if !iscomplete(sys)
488458
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearLeastSquaresProblem`")
489459
end
490-
f, u0, p = process_NonlinearProblem(NonlinearFunction{iip}, sys, u0map, parammap;
460+
f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap;
491461
check_length, kwargs...)
492462
pt = something(get_metadata(sys), StandardNonlinearProblem())
493463
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
@@ -520,7 +490,7 @@ function NonlinearProblemExpr{iip}(sys::NonlinearSystem, u0map,
520490
if !iscomplete(sys)
521491
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblemExpr`")
522492
end
523-
f, u0, p = process_NonlinearProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap;
493+
f, u0, p = process_SciMLProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap;
524494
check_length, kwargs...)
525495
linenumbers = get(kwargs, :linenumbers, true)
526496

@@ -560,7 +530,7 @@ function NonlinearLeastSquaresProblemExpr{iip}(sys::NonlinearSystem, u0map,
560530
if !iscomplete(sys)
561531
error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblemExpr`")
562532
end
563-
f, u0, p = process_NonlinearProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap;
533+
f, u0, p = process_SciMLProblem(NonlinearFunctionExpr{iip}, sys, u0map, parammap;
564534
check_length, kwargs...)
565535
linenumbers = get(kwargs, :linenumbers, true)
566536

0 commit comments

Comments
 (0)