Skip to content

Commit aff4ae3

Browse files
authored
Merge pull request #2102 from SciML/fb/linear_opt
Optimize `linearize` by avoiding function generation
2 parents 97a0cb8 + 65a6f3e commit aff4ae3

File tree

7 files changed

+30
-21
lines changed

7 files changed

+30
-21
lines changed

docs/src/systems/DiscreteSystem.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ DiscreteSystem
1111
- `get_eqs(sys)` or `equations(sys)`: The equations that define the Discrete System.
1212
- `get_delay_val(sys)`: The delay of the Discrete System.
1313
- `get_iv(sys)`: The independent variable of the Discrete System.
14+
- `get_u0_p(sys, u0map, parammap)` Numeric arrays for the initial condition and parameters given `var => value` maps.
1415

1516
## Transformations
1617

docs/src/systems/NonlinearSystem.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ NonlinearSystem
1111
- `get_eqs(sys)` or `equations(sys)`: The equations that define the nonlinear system.
1212
- `get_states(sys)` or `states(sys)`: The set of states in the nonlinear system.
1313
- `get_ps(sys)` or `parameters(sys)`: The parameters of the nonlinear system.
14+
- `get_u0_p(sys, u0map, parammap)` Numeric arrays for the initial condition and parameters given `var => value` maps.
1415

1516
## Transformations
1617

docs/src/systems/ODESystem.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ ODESystem
1212
- `get_states(sys)` or `states(sys)`: The set of states in the ODE.
1313
- `get_ps(sys)` or `parameters(sys)`: The parameters of the ODE.
1414
- `get_iv(sys)`: The independent variable of the ODE.
15+
- `get_u0_p(sys, u0map, parammap)` Numeric arrays for the initial condition and parameters given `var => value` maps.
1516

1617
## Transformations
1718

src/systems/abstractsystem.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,9 +1386,9 @@ lsys = ModelingToolkit.reorder_states(lsys, states(ssys), desired_order)
13861386
function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives = false,
13871387
p = DiffEqBase.NullParameters())
13881388
x0 = merge(defaults(sys), op)
1389-
f, u0, p = process_DEProblem(ODEFunction{true}, sys, x0, p)
1389+
u0, p2, _ = get_u0_p(sys, x0, p; use_union = false, tofloat = true)
13901390

1391-
linres = lin_fun(u0, p, t)
1391+
linres = lin_fun(u0, p2, t)
13921392
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
13931393

13941394
nx, nu = size(f_u)

src/systems/diffeqs/abstractodesystem.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,26 @@ function ODEFunctionExpr{iip}(sys::AbstractODESystem, dvs = states(sys),
570570
!linenumbers ? striplines(ex) : ex
571571
end
572572

573+
"""
574+
u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union)
575+
576+
Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point.
577+
"""
578+
function get_u0_p(sys, u0map, parammap; use_union = false, tofloat = !use_union)
579+
eqs = equations(sys)
580+
dvs = states(sys)
581+
ps = parameters(sys)
582+
583+
defs = defaults(sys)
584+
defs = mergedefaults(defs, parammap, ps)
585+
defs = mergedefaults(defs, u0map, dvs)
586+
587+
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
588+
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
589+
p = p === nothing ? SciMLBase.NullParameters() : p
590+
u0, p, defs
591+
end
592+
573593
function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
574594
implicit_dae = false, du0map = nothing,
575595
version = nothing, tgrad = false,
@@ -586,13 +606,7 @@ function process_DEProblem(constructor, sys::AbstractODESystem, u0map, parammap;
586606
ps = parameters(sys)
587607
iv = get_iv(sys)
588608

589-
defs = defaults(sys)
590-
defs = mergedefaults(defs, parammap, ps)
591-
defs = mergedefaults(defs, u0map, dvs)
592-
593-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
594-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat, use_union)
595-
p = p === nothing ? SciMLBase.NullParameters() : p
609+
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
596610

597611
if implicit_dae && du0map !== nothing
598612
ddvs = map(Differential(iv), dvs)

src/systems/discrete_system/discrete_system.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -418,17 +418,13 @@ function process_DiscreteProblem(constructor, sys::DiscreteSystem, u0map, paramm
418418
linenumbers = true, parallel = SerialForm(),
419419
eval_expression = true,
420420
use_union = false,
421+
tofloat = !use_union,
421422
kwargs...)
422423
eqs = equations(sys)
423424
dvs = states(sys)
424425
ps = parameters(sys)
425426

426-
defs = defaults(sys)
427-
defs = mergedefaults(defs, parammap, ps)
428-
defs = mergedefaults(defs, u0map, dvs)
429-
430-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
431-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union, use_union)
427+
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
432428

433429
check_eqs_u0(eqs, dvs, u0; kwargs...)
434430

src/systems/nonlinear/nonlinearsystem.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -320,17 +320,13 @@ function process_NonlinearProblem(constructor, sys::NonlinearSystem, u0map, para
320320
linenumbers = true, parallel = SerialForm(),
321321
eval_expression = true,
322322
use_union = false,
323+
tofloat = !use_union,
323324
kwargs...)
324325
eqs = equations(sys)
325326
dvs = states(sys)
326327
ps = parameters(sys)
327328

328-
defs = defaults(sys)
329-
defs = mergedefaults(defs, parammap, ps)
330-
defs = mergedefaults(defs, u0map, dvs)
331-
332-
u0 = varmap_to_vars(u0map, dvs; defaults = defs, tofloat = true)
333-
p = varmap_to_vars(parammap, ps; defaults = defs, tofloat = !use_union, use_union)
329+
u0, p, defs = get_u0_p(sys, u0map, parammap; tofloat, use_union)
334330

335331
check_eqs_u0(eqs, dvs, u0; kwargs...)
336332

0 commit comments

Comments
 (0)