Skip to content

Commit 382ce9b

Browse files
refactor: use new LinearizationFunction instead of closure in linearization_function
1 parent bfae1d2 commit 382ce9b

File tree

1 file changed

+115
-68
lines changed

1 file changed

+115
-68
lines changed

src/linearization.jl

Lines changed: 115 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -75,78 +75,125 @@ function linearization_function(sys::AbstractSystem, inputs,
7575

7676
ps = parameters(sys)
7777
h = build_explicit_observed_function(sys, outputs; eval_expression, eval_module)
78-
lin_fun = let diff_idxs = diff_idxs,
79-
alge_idxs = alge_idxs,
80-
input_idxs = input_idxs,
81-
sts = unknowns(sys),
82-
fun = fun,
83-
prob = prob,
84-
sys_ps = p,
85-
h = h,
86-
integ_cache = (similar(u0)),
87-
chunk = ForwardDiff.Chunk(input_idxs),
88-
initializealg = initializealg,
89-
initialization_abstol = initialization_abstol,
90-
initialization_reltol = initialization_reltol,
91-
initialization_solver_alg = initialization_solver_alg,
92-
sys = sys
93-
94-
function (u, p, t)
95-
if !isa(p, MTKParameters)
96-
p = todict(p)
97-
newps = deepcopy(sys_ps)
98-
for (k, v) in p
99-
if is_parameter(sys, k)
100-
v = fixpoint_sub(v, p)
101-
setp(sys, k)(newps, v)
102-
end
103-
end
104-
p = newps
105-
end
10678

107-
if u !== nothing # Handle systems without unknowns
108-
length(sts) == length(u) ||
109-
error("Number of unknown variables ($(length(sts))) does not match the number of input unknowns ($(length(u)))")
110-
111-
integ = MockIntegrator{true}(u, p, t, integ_cache)
112-
u, p, success = SciMLBase.get_initial_values(
113-
prob, integ, fun, initializealg, Val(true);
114-
abstol = initialization_abstol, reltol = initialization_reltol,
115-
nlsolve_alg = initialization_solver_alg)
116-
if !success
117-
error("Initialization algorithm $(initializealg) failed with `u = $u` and `p = $p`.")
118-
end
119-
uf = SciMLBase.UJacobianWrapper(fun, t, p)
120-
fg_xz = ForwardDiff.jacobian(uf, u)
121-
h_xz = ForwardDiff.jacobian(
122-
let p = p, t = t
123-
xz -> h(xz, p, t)
124-
end, u)
125-
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
126-
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
127-
else
128-
length(sts) == 0 ||
129-
error("Number of unknown variables (0) does not match the number of input unknowns ($(length(u)))")
130-
fg_xz = zeros(0, 0)
131-
h_xz = fg_u = zeros(0, length(inputs))
132-
end
133-
hp = let u = u, t = t
134-
_hp(p) = h(u, p, t)
135-
_hp
79+
initialization_kwargs = (;
80+
abstol = initialization_abstol, reltol = initialization_reltol,
81+
nlsolve_alg = initialization_solver_alg)
82+
lin_fun = LinearizationFunction(
83+
diff_idxs, alge_idxs, input_idxs, length(unknowns(sys)), prob, h, similar(u0),
84+
ForwardDiff.Chunk(input_idxs), initializealg, initialization_kwargs)
85+
return lin_fun, sys
86+
end
87+
88+
"""
89+
$(TYPEDEF)
90+
91+
A callable struct which linearizes a system.
92+
93+
# Fields
94+
95+
$(TYPEDFIELDS)
96+
"""
97+
struct LinearizationFunction{
98+
DI <: AbstractVector{Int}, AI <: AbstractVector{Int}, II, P <: ODEProblem,
99+
H, C, Ch, IA <: SciMLBase.DAEInitializationAlgorithm, IK}
100+
"""
101+
The indexes of differential equations in the linearized system.
102+
"""
103+
diff_idxs::DI
104+
"""
105+
The indexes of algebraic equations in the linearized system.
106+
"""
107+
alge_idxs::AI
108+
"""
109+
The indexes of parameters in the linearized system which represent
110+
input variables.
111+
"""
112+
input_idxs::II
113+
"""
114+
The number of unknowns in the linearized system.
115+
"""
116+
num_states::Int
117+
"""
118+
The `ODEProblem` of the linearized system.
119+
"""
120+
prob::P
121+
"""
122+
A function which takes `(u, p, t)` and returns the outputs of the linearized system.
123+
"""
124+
h::H
125+
"""
126+
Any required cache buffers.
127+
"""
128+
caches::C
129+
# TODO: Use DI?
130+
"""
131+
A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs.
132+
"""
133+
chunk::Ch
134+
"""
135+
The initialization algorithm to use.
136+
"""
137+
initializealg::IA
138+
"""
139+
Keyword arguments to be passed to `SciMLBase.get_initial_values`.
140+
"""
141+
initialize_kwargs::IK
142+
end
143+
144+
function (linfun::LinearizationFunction)(u, p, t)
145+
if eltype(p) <: Pair
146+
p = todict(p)
147+
newps = copy(parameter_values(linfun.prob))
148+
for (k, v) in p
149+
if is_parameter(linfun, k)
150+
v = fixpoint_sub(v, p)
151+
setp(linfun, k)(newps, v)
136152
end
137-
h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk)
138-
(f_x = fg_xz[diff_idxs, diff_idxs],
139-
f_z = fg_xz[diff_idxs, alge_idxs],
140-
g_x = fg_xz[alge_idxs, diff_idxs],
141-
g_z = fg_xz[alge_idxs, alge_idxs],
142-
f_u = fg_u[diff_idxs, :],
143-
g_u = fg_u[alge_idxs, :],
144-
h_x = h_xz[:, diff_idxs],
145-
h_z = h_xz[:, alge_idxs],
146-
h_u = h_u)
147153
end
154+
p = newps
148155
end
149-
return lin_fun, sys
156+
157+
fun = linfun.prob.f
158+
if u !== nothing # Handle systems without unknowns
159+
linfun.num_states == length(u) ||
160+
error("Number of unknown variables ($(linfun.num_states)) does not match the number of input unknowns ($(length(u)))")
161+
integ_cache = linfun.caches
162+
integ = MockIntegrator{true}(u, p, t, integ_cache)
163+
u, p, success = SciMLBase.get_initial_values(
164+
linfun.prob, integ, fun, linfun.initializealg, Val(true);
165+
linfun.initialize_kwargs...)
166+
if !success
167+
error("Initialization algorithm $(linfun.initializealg) failed with `u = $u` and `p = $p`.")
168+
end
169+
uf = SciMLBase.UJacobianWrapper(fun, t, p)
170+
fg_xz = ForwardDiff.jacobian(uf, u)
171+
h_xz = ForwardDiff.jacobian(
172+
let p = p, t = t, h = linfun.h
173+
xz -> h(xz, p, t)
174+
end, u)
175+
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
176+
fg_u = jacobian_wrt_vars(pf, p, linfun.input_idxs, linfun.chunk)
177+
else
178+
linfun.num_states == 0 ||
179+
error("Number of unknown variables (0) does not match the number of input unknowns ($(length(u)))")
180+
fg_xz = zeros(0, 0)
181+
h_xz = fg_u = zeros(0, length(linfun.input_idxs))
182+
end
183+
hp = let u = u, t = t, h = linfun.h
184+
_hp(p) = h(u, p, t)
185+
_hp
186+
end
187+
h_u = jacobian_wrt_vars(hp, p, linfun.input_idxs, linfun.chunk)
188+
(f_x = fg_xz[linfun.diff_idxs, linfun.diff_idxs],
189+
f_z = fg_xz[linfun.diff_idxs, linfun.alge_idxs],
190+
g_x = fg_xz[linfun.alge_idxs, linfun.diff_idxs],
191+
g_z = fg_xz[linfun.alge_idxs, linfun.alge_idxs],
192+
f_u = fg_u[linfun.diff_idxs, :],
193+
g_u = fg_u[linfun.alge_idxs, :],
194+
h_x = h_xz[:, linfun.diff_idxs],
195+
h_z = h_xz[:, linfun.alge_idxs],
196+
h_u = h_u)
150197
end
151198

152199
"""

0 commit comments

Comments
 (0)