@@ -75,78 +75,125 @@ function linearization_function(sys::AbstractSystem, inputs,
7575
7676 ps = parameters (sys)
7777 h = build_explicit_observed_function (sys, outputs; eval_expression, eval_module)
78- lin_fun = let diff_idxs = diff_idxs,
79- alge_idxs = alge_idxs,
80- input_idxs = input_idxs,
81- sts = unknowns (sys),
82- fun = fun,
83- prob = prob,
84- sys_ps = p,
85- h = h,
86- integ_cache = (similar (u0)),
87- chunk = ForwardDiff. Chunk (input_idxs),
88- initializealg = initializealg,
89- initialization_abstol = initialization_abstol,
90- initialization_reltol = initialization_reltol,
91- initialization_solver_alg = initialization_solver_alg,
92- sys = sys
93-
94- function (u, p, t)
95- if ! isa (p, MTKParameters)
96- p = todict (p)
97- newps = deepcopy (sys_ps)
98- for (k, v) in p
99- if is_parameter (sys, k)
100- v = fixpoint_sub (v, p)
101- setp (sys, k)(newps, v)
102- end
103- end
104- p = newps
105- end
10678
107- if u != = nothing # Handle systems without unknowns
108- length (sts) == length (u) ||
109- error (" Number of unknown variables ($(length (sts)) ) does not match the number of input unknowns ($(length (u)) )" )
110-
111- integ = MockIntegrator {true} (u, p, t, integ_cache)
112- u, p, success = SciMLBase. get_initial_values (
113- prob, integ, fun, initializealg, Val (true );
114- abstol = initialization_abstol, reltol = initialization_reltol,
115- nlsolve_alg = initialization_solver_alg)
116- if ! success
117- error (" Initialization algorithm $(initializealg) failed with `u = $u ` and `p = $p `." )
118- end
119- uf = SciMLBase. UJacobianWrapper (fun, t, p)
120- fg_xz = ForwardDiff. jacobian (uf, u)
121- h_xz = ForwardDiff. jacobian (
122- let p = p, t = t
123- xz -> h (xz, p, t)
124- end , u)
125- pf = SciMLBase. ParamJacobianWrapper (fun, t, u)
126- fg_u = jacobian_wrt_vars (pf, p, input_idxs, chunk)
127- else
128- length (sts) == 0 ||
129- error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" )
130- fg_xz = zeros (0 , 0 )
131- h_xz = fg_u = zeros (0 , length (inputs))
132- end
133- hp = let u = u, t = t
134- _hp (p) = h (u, p, t)
135- _hp
79+ initialization_kwargs = (;
80+ abstol = initialization_abstol, reltol = initialization_reltol,
81+ nlsolve_alg = initialization_solver_alg)
82+ lin_fun = LinearizationFunction (
83+ diff_idxs, alge_idxs, input_idxs, length (unknowns (sys)), prob, h, similar (u0),
84+ ForwardDiff. Chunk (input_idxs), initializealg, initialization_kwargs)
85+ return lin_fun, sys
86+ end
87+
88+ """
89+ $(TYPEDEF)
90+
91+ A callable struct which linearizes a system.
92+
93+ # Fields
94+
95+ $(TYPEDFIELDS)
96+ """
97+ struct LinearizationFunction{
98+ DI <: AbstractVector{Int} , AI <: AbstractVector{Int} , II, P <: ODEProblem ,
99+ H, C, Ch, IA <: SciMLBase.DAEInitializationAlgorithm , IK}
100+ """
101+ The indexes of differential equations in the linearized system.
102+ """
103+ diff_idxs:: DI
104+ """
105+ The indexes of algebraic equations in the linearized system.
106+ """
107+ alge_idxs:: AI
108+ """
109+ The indexes of parameters in the linearized system which represent
110+ input variables.
111+ """
112+ input_idxs:: II
113+ """
114+ The number of unknowns in the linearized system.
115+ """
116+ num_states:: Int
117+ """
118+ The `ODEProblem` of the linearized system.
119+ """
120+ prob:: P
121+ """
122+ A function which takes `(u, p, t)` and returns the outputs of the linearized system.
123+ """
124+ h:: H
125+ """
126+ Any required cache buffers.
127+ """
128+ caches:: C
129+ # TODO : Use DI?
130+ """
131+ A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs.
132+ """
133+ chunk:: Ch
134+ """
135+ The initialization algorithm to use.
136+ """
137+ initializealg:: IA
138+ """
139+ Keyword arguments to be passed to `SciMLBase.get_initial_values`.
140+ """
141+ initialize_kwargs:: IK
142+ end
143+
144+ function (linfun:: LinearizationFunction )(u, p, t)
145+ if eltype (p) <: Pair
146+ p = todict (p)
147+ newps = copy (parameter_values (linfun. prob))
148+ for (k, v) in p
149+ if is_parameter (sys, k)
150+ v = fixpoint_sub (v, p)
151+ setp (sys, k)(newps, v)
136152 end
137- h_u = jacobian_wrt_vars (hp, p, input_idxs, chunk)
138- (f_x = fg_xz[diff_idxs, diff_idxs],
139- f_z = fg_xz[diff_idxs, alge_idxs],
140- g_x = fg_xz[alge_idxs, diff_idxs],
141- g_z = fg_xz[alge_idxs, alge_idxs],
142- f_u = fg_u[diff_idxs, :],
143- g_u = fg_u[alge_idxs, :],
144- h_x = h_xz[:, diff_idxs],
145- h_z = h_xz[:, alge_idxs],
146- h_u = h_u)
147153 end
154+ p = newps
148155 end
149- return lin_fun, sys
156+
157+ fun = linfun. prob. f
158+ if u != = nothing # Handle systems without unknowns
159+ linfun. num_states == length (u) ||
160+ error (" Number of unknown variables ($(linfun. num_states) ) does not match the number of input unknowns ($(length (u)) )" )
161+ integ_cache = linfun. caches
162+ integ = MockIntegrator {true} (u, p, t, integ_cache)
163+ u, p, success = SciMLBase. get_initial_values (
164+ linfun. prob, integ, fun, linfun. initializealg, Val (true );
165+ linfun. initialize_kwargs... )
166+ if ! success
167+ error (" Initialization algorithm $(linfun. initializealg) failed with `u = $u ` and `p = $p `." )
168+ end
169+ uf = SciMLBase. UJacobianWrapper (fun, t, p)
170+ fg_xz = ForwardDiff. jacobian (uf, u)
171+ h_xz = ForwardDiff. jacobian (
172+ let p = p, t = t, h = linfun. h
173+ xz -> h (xz, p, t)
174+ end , u)
175+ pf = SciMLBase. ParamJacobianWrapper (fun, t, u)
176+ fg_u = jacobian_wrt_vars (pf, p, linfun. input_idxs, linfun. chunk)
177+ else
178+ linfun. num_states == 0 ||
179+ error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" )
180+ fg_xz = zeros (0 , 0 )
181+ h_xz = fg_u = zeros (0 , length (inputs))
182+ end
183+ hp = let u = u, t = t, h = linfun. h
184+ _hp (p) = h (u, p, t)
185+ _hp
186+ end
187+ h_u = jacobian_wrt_vars (hp, p, linfun. input_idxs, linfun. chunk)
188+ (f_x = fg_xz[linfun. diff_idxs, linfun. diff_idxs],
189+ f_z = fg_xz[linfun. diff_idxs, linfun. alge_idxs],
190+ g_x = fg_xz[linfun. alge_idxs, linfun. diff_idxs],
191+ g_z = fg_xz[linfun. alge_idxs, linfun. alge_idxs],
192+ f_u = fg_u[linfun. diff_idxs, :],
193+ g_u = fg_u[linfun. alge_idxs, :],
194+ h_x = h_xz[:, linfun. diff_idxs],
195+ h_z = h_xz[:, linfun. alge_idxs],
196+ h_u = h_u)
150197end
151198
152199"""
0 commit comments