Skip to content

Commit debd300

Browse files
committed
output extra information from linearization
1 parent 1fcf103 commit debd300

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

src/linearization.jl

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ end
157157
"""
158158
$(TYPEDSIGNATURES)
159159
160-
Linearize the wrapped system at the point given by `(u, p, t)`.
160+
Linearize the wrapped system at the point given by `(unknowns, p, t)`.
161161
"""
162162
function (linfun::LinearizationFunction)(u, p, t)
163163
if eltype(p) <: Pair
@@ -182,7 +182,7 @@ function (linfun::LinearizationFunction)(u, p, t)
182182
linfun.prob, integ, fun, linfun.initializealg, Val(true);
183183
linfun.initialize_kwargs...)
184184
if !success
185-
error("Initialization algorithm $(linfun.initializealg) failed with `u = $u` and `p = $p`.")
185+
error("Initialization algorithm $(linfun.initializealg) failed with `unknowns = $u` and `p = $p`.")
186186
end
187187
uf = SciMLBase.UJacobianWrapper(fun, t, p)
188188
fg_xz = ForwardDiff.jacobian(uf, u)
@@ -211,7 +211,11 @@ function (linfun::LinearizationFunction)(u, p, t)
211211
g_u = fg_u[linfun.alge_idxs, :],
212212
h_x = h_xz[:, linfun.diff_idxs],
213213
h_z = h_xz[:, linfun.alge_idxs],
214-
h_u = h_u)
214+
h_u = h_u,
215+
x = u,
216+
p,
217+
t,
218+
success)
215219
end
216220

217221
"""
@@ -319,7 +323,7 @@ function CommonSolve.solve(prob::LinearizationProblem; allow_input_derivatives =
319323
p = parameter_values(prob)
320324
t = current_time(prob)
321325
linres = prob.f(u0, p, t)
322-
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
326+
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u, x, p, t, success = linres
323327

324328
nx, nu = size(f_u)
325329
nz = size(f_z, 2)
@@ -356,7 +360,7 @@ function CommonSolve.solve(prob::LinearizationProblem; allow_input_derivatives =
356360
end
357361
end
358362

359-
(; A, B, C, D)
363+
(; A, B, C, D), (; x, p, t, success)
360364
end
361365

362366
"""
@@ -487,8 +491,8 @@ function markio!(state, orig_inputs, inputs, outputs; check = true)
487491
end
488492

489493
"""
490-
(; A, B, C, D), simplified_sys = linearize(sys, inputs, outputs; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false, kwargs...)
491-
(; A, B, C, D) = linearize(simplified_sys, lin_fun; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false)
494+
(; A, B, C, D), simplified_sys, extras = linearize(sys, inputs, outputs; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false, kwargs...)
495+
(; A, B, C, D), extras = linearize(simplified_sys, lin_fun; t=0.0, op = Dict(), allow_input_derivatives = false, zero_dummy_der=false)
492496
493497
Linearize `sys` between `inputs` and `outputs`, both vectors of variables. Return a NamedTuple with the matrices of a linear statespace representation
494498
on the form
@@ -510,6 +514,8 @@ If `allow_input_derivatives = false`, an error will be thrown if input derivativ
510514
511515
`zero_dummy_der` can be set to automatically set the operating point to zero for all dummy derivatives.
512516
517+
The return value `extras` is a NamedTuple `(; x, p, t, success)` containing the result of the initialization problem that was solved to determine the operating point.
518+
513519
See also [`linearization_function`](@ref) which provides a lower-level interface, [`linearize_symbolic`](@ref) and [`ModelingToolkit.reorder_unknowns`](@ref).
514520
515521
See extended help for an example.
@@ -616,7 +622,8 @@ function linearize(sys, inputs, outputs; op = Dict(), t = 0.0,
616622
zero_dummy_der,
617623
op,
618624
kwargs...)
619-
linearize(ssys, lin_fun; op, t, allow_input_derivatives), ssys
625+
mats, extras = linearize(ssys, lin_fun; op, t, allow_input_derivatives)
626+
mats, ssys, extras
620627
end
621628

622629
"""

test/downstream/linearize.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ eqs = [u ~ kp * (r - y)
1414

1515
@named sys = ODESystem(eqs, t)
1616

17-
lsys, ssys = linearize(sys, [r], [y])
17+
lsys, ssys, extras = linearize(sys, [r], [y])
1818
lprob = LinearizationProblem(sys, [r], [y])
19-
lsys2 = solve(lprob)
19+
lsys2, extras2 = solve(lprob)
2020

2121
@test lsys.A[] == lsys2.A[] == -2
2222
@test lsys.B[] == lsys2.B[] == 1
2323
@test lsys.C[] == lsys2.C[] == 1
2424
@test lsys.D[] == lsys2.D[] == 0
25+
@test extras == extras2
2526

2627
lsys, ssys = linearize(sys, [r], [r])
2728

0 commit comments

Comments
 (0)