Skip to content

Commit f0590ae

Browse files
feat: add LinearizationProblem
1 parent 382ce9b commit f0590ae

File tree

1 file changed

+76
-50
lines changed

1 file changed

+76
-50
lines changed

src/linearization.jl

Lines changed: 76 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ struct LinearizationFunction{
141141
initialize_kwargs::IK
142142
end
143143

144+
SymbolicIndexingInterface.symbolic_container(f::LinearizationFunction) = f.prob
145+
SymbolicIndexingInterface.state_values(f::LinearizationFunction) = state_values(f.prob)
146+
function SymbolicIndexingInterface.parameter_values(f::LinearizationFunction)
147+
parameter_values(f.prob)
148+
end
149+
SymbolicIndexingInterface.current_time(f::LinearizationFunction) = current_time(f.prob)
150+
144151
function (linfun::LinearizationFunction)(u, p, t)
145152
if eltype(p) <: Pair
146153
p = todict(p)
@@ -234,6 +241,61 @@ SymbolicIndexingInterface.parameter_values(integ::MockIntegrator) = integ.p
234241
SymbolicIndexingInterface.current_time(integ::MockIntegrator) = integ.t
235242
SciMLBase.get_tmp_cache(integ::MockIntegrator) = integ.cache
236243

244+
mutable struct LinearizationProblem{F <: LinearizationFunction, T}
245+
const f::F
246+
t::T
247+
end
248+
249+
SymbolicIndexingInterface.symbolic_container(p::LinearizationProblem) = p.f
250+
SymbolicIndexingInterface.state_values(p::LinearizationProblem) = state_values(p.f)
251+
SymbolicIndexingInterface.parameter_values(p::LinearizationProblem) = parameter_values(p.f)
252+
SymbolicIndexingInterface.current_time(p::LinearizationProblem) = p.t
253+
254+
function CommonSolve.solve(prob::LinearizationProblem; allow_input_derivatives = false)
255+
u0 = state_values(prob)
256+
p = parameter_values(prob)
257+
t = current_time(prob)
258+
linres = prob.f(u0, p, t)
259+
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
260+
261+
nx, nu = size(f_u)
262+
nz = size(f_z, 2)
263+
ny = size(h_x, 1)
264+
265+
D = h_u
266+
267+
if isempty(g_z)
268+
A = f_x
269+
B = f_u
270+
C = h_x
271+
@assert iszero(g_x)
272+
@assert iszero(g_z)
273+
@assert iszero(g_u)
274+
else
275+
gz = lu(g_z; check = false)
276+
issuccess(gz) ||
277+
error("g_z not invertible, this indicates that the DAE is of index > 1.")
278+
gzgx = -(gz \ g_x)
279+
A = [f_x f_z
280+
gzgx*f_x gzgx*f_z]
281+
B = [f_u
282+
gzgx * f_u] # The cited paper has zeros in the bottom block, see derivation in https://github.com/SciML/ModelingToolkit.jl/pull/1691 for the correct formula
283+
284+
C = [h_x h_z]
285+
Bs = -(gz \ g_u) # This equation differ from the cited paper, the paper is likely wrong since their equaiton leads to a dimension mismatch.
286+
if !iszero(Bs)
287+
if !allow_input_derivatives
288+
der_inds = findall(vec(any(!=(0), Bs, dims = 1)))
289+
error("Input derivatives appeared in expressions (-g_z\\g_u != 0), the following inputs appeared differentiated: $(inputs(sys)[der_inds]). Call `linearize` with keyword argument `allow_input_derivatives = true` to allow this and have the returned `B` matrix be of double width ($(2nu)), where the last $nu inputs are the derivatives of the first $nu inputs.")
290+
end
291+
B = [B [zeros(nx, nu); Bs]]
292+
D = [D zeros(ny, nu)]
293+
end
294+
end
295+
296+
(; A, B, C, D)
297+
end
298+
237299
"""
238300
(; A, B, C, D), simplified_sys = linearize_symbolic(sys::AbstractSystem, inputs, outputs; simplify = false, allow_input_derivatives = false, kwargs...)
239301
@@ -460,60 +522,24 @@ lsys_sym, _ = ModelingToolkit.linearize_symbolic(cl, [f.u], [p.x])
460522
@assert substitute(lsys_sym.A, ModelingToolkit.defaults(cl)) == lsys.A
461523
```
462524
"""
463-
function linearize(sys, lin_fun; t = 0.0, op = Dict(), allow_input_derivatives = false,
525+
function linearize(sys, lin_fun::LinearizationFunction; t = 0.0,
526+
op = Dict(), allow_input_derivatives = false,
464527
p = DiffEqBase.NullParameters())
465-
x0 = merge(defaults(sys), Dict(missing_variable_defaults(sys)), op)
466-
u0, defs = get_u0(sys, x0, p)
467-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
468-
if p isa SciMLBase.NullParameters
469-
p = op
470-
elseif p isa Dict
471-
p = merge(p, op)
472-
elseif p isa Vector && eltype(p) <: Pair
473-
p = merge(Dict(p), op)
474-
elseif p isa Vector
475-
p = merge(Dict(parameters(sys) .=> p), op)
528+
prob = LinearizationProblem(lin_fun, t)
529+
op = anydict(op)
530+
evaluate_varmap!(op, unknowns(sys))
531+
for (k, v) in op
532+
if is_parameter(prob, Initial(k))
533+
setu(prob, Initial(k))(prob, v)
534+
else
535+
setu(prob, k)(prob, v)
476536
end
477537
end
478-
linres = lin_fun(u0, p, t)
479-
f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u = linres
480-
481-
nx, nu = size(f_u)
482-
nz = size(f_z, 2)
483-
ny = size(h_x, 1)
484-
485-
D = h_u
486-
487-
if isempty(g_z)
488-
A = f_x
489-
B = f_u
490-
C = h_x
491-
@assert iszero(g_x)
492-
@assert iszero(g_z)
493-
@assert iszero(g_u)
494-
else
495-
gz = lu(g_z; check = false)
496-
issuccess(gz) ||
497-
error("g_z not invertible, this indicates that the DAE is of index > 1.")
498-
gzgx = -(gz \ g_x)
499-
A = [f_x f_z
500-
gzgx*f_x gzgx*f_z]
501-
B = [f_u
502-
gzgx * f_u] # The cited paper has zeros in the bottom block, see derivation in https://github.com/SciML/ModelingToolkit.jl/pull/1691 for the correct formula
503-
504-
C = [h_x h_z]
505-
Bs = -(gz \ g_u) # This equation differ from the cited paper, the paper is likely wrong since their equaiton leads to a dimension mismatch.
506-
if !iszero(Bs)
507-
if !allow_input_derivatives
508-
der_inds = findall(vec(any(!=(0), Bs, dims = 1)))
509-
error("Input derivatives appeared in expressions (-g_z\\g_u != 0), the following inputs appeared differentiated: $(inputs(sys)[der_inds]). Call `linearize` with keyword argument `allow_input_derivatives = true` to allow this and have the returned `B` matrix be of double width ($(2nu)), where the last $nu inputs are the derivatives of the first $nu inputs.")
510-
end
511-
B = [B [zeros(nx, nu); Bs]]
512-
D = [D zeros(ny, nu)]
513-
end
538+
p = anydict(p)
539+
for (k, v) in p
540+
setu(prob, k)(prob, v)
514541
end
515-
516-
(; A, B, C, D)
542+
return solve(prob; allow_input_derivatives)
517543
end
518544

519545
function linearize(sys, inputs, outputs; op = Dict(), t = 0.0,

0 commit comments

Comments
 (0)