@@ -75,78 +75,125 @@ function linearization_function(sys::AbstractSystem, inputs,
75
75
76
76
ps = parameters (sys)
77
77
h = build_explicit_observed_function (sys, outputs; eval_expression, eval_module)
78
- lin_fun = let diff_idxs = diff_idxs,
79
- alge_idxs = alge_idxs,
80
- input_idxs = input_idxs,
81
- sts = unknowns (sys),
82
- fun = fun,
83
- prob = prob,
84
- sys_ps = p,
85
- h = h,
86
- integ_cache = (similar (u0)),
87
- chunk = ForwardDiff. Chunk (input_idxs),
88
- initializealg = initializealg,
89
- initialization_abstol = initialization_abstol,
90
- initialization_reltol = initialization_reltol,
91
- initialization_solver_alg = initialization_solver_alg,
92
- sys = sys
93
-
94
- function (u, p, t)
95
- if ! isa (p, MTKParameters)
96
- p = todict (p)
97
- newps = deepcopy (sys_ps)
98
- for (k, v) in p
99
- if is_parameter (sys, k)
100
- v = fixpoint_sub (v, p)
101
- setp (sys, k)(newps, v)
102
- end
103
- end
104
- p = newps
105
- end
106
78
107
- if u != = nothing # Handle systems without unknowns
108
- length (sts) == length (u) ||
109
- error (" Number of unknown variables ($(length (sts)) ) does not match the number of input unknowns ($(length (u)) )" )
110
-
111
- integ = MockIntegrator {true} (u, p, t, integ_cache)
112
- u, p, success = SciMLBase. get_initial_values (
113
- prob, integ, fun, initializealg, Val (true );
114
- abstol = initialization_abstol, reltol = initialization_reltol,
115
- nlsolve_alg = initialization_solver_alg)
116
- if ! success
117
- error (" Initialization algorithm $(initializealg) failed with `u = $u ` and `p = $p `." )
118
- end
119
- uf = SciMLBase. UJacobianWrapper (fun, t, p)
120
- fg_xz = ForwardDiff. jacobian (uf, u)
121
- h_xz = ForwardDiff. jacobian (
122
- let p = p, t = t
123
- xz -> h (xz, p, t)
124
- end , u)
125
- pf = SciMLBase. ParamJacobianWrapper (fun, t, u)
126
- fg_u = jacobian_wrt_vars (pf, p, input_idxs, chunk)
127
- else
128
- length (sts) == 0 ||
129
- error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" )
130
- fg_xz = zeros (0 , 0 )
131
- h_xz = fg_u = zeros (0 , length (inputs))
132
- end
133
- hp = let u = u, t = t
134
- _hp (p) = h (u, p, t)
135
- _hp
79
+ initialization_kwargs = (;
80
+ abstol = initialization_abstol, reltol = initialization_reltol,
81
+ nlsolve_alg = initialization_solver_alg)
82
+ lin_fun = LinearizationFunction (
83
+ diff_idxs, alge_idxs, input_idxs, length (unknowns (sys)), prob, h, similar (u0),
84
+ ForwardDiff. Chunk (input_idxs), initializealg, initialization_kwargs)
85
+ return lin_fun, sys
86
+ end
87
+
88
+ """
89
+ $(TYPEDEF)
90
+
91
+ A callable struct which linearizes a system.
92
+
93
+ # Fields
94
+
95
+ $(TYPEDFIELDS)
96
+ """
97
+ struct LinearizationFunction{
98
+ DI <: AbstractVector{Int} , AI <: AbstractVector{Int} , II, P <: ODEProblem ,
99
+ H, C, Ch, IA <: SciMLBase.DAEInitializationAlgorithm , IK}
100
+ """
101
+ The indexes of differential equations in the linearized system.
102
+ """
103
+ diff_idxs:: DI
104
+ """
105
+ The indexes of algebraic equations in the linearized system.
106
+ """
107
+ alge_idxs:: AI
108
+ """
109
+ The indexes of parameters in the linearized system which represent
110
+ input variables.
111
+ """
112
+ input_idxs:: II
113
+ """
114
+ The number of unknowns in the linearized system.
115
+ """
116
+ num_states:: Int
117
+ """
118
+ The `ODEProblem` of the linearized system.
119
+ """
120
+ prob:: P
121
+ """
122
+ A function which takes `(u, p, t)` and returns the outputs of the linearized system.
123
+ """
124
+ h:: H
125
+ """
126
+ Any required cache buffers.
127
+ """
128
+ caches:: C
129
+ # TODO : Use DI?
130
+ """
131
+ A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs.
132
+ """
133
+ chunk:: Ch
134
+ """
135
+ The initialization algorithm to use.
136
+ """
137
+ initializealg:: IA
138
+ """
139
+ Keyword arguments to be passed to `SciMLBase.get_initial_values`.
140
+ """
141
+ initialize_kwargs:: IK
142
+ end
143
+
144
+ function (linfun:: LinearizationFunction )(u, p, t)
145
+ if eltype (p) <: Pair
146
+ p = todict (p)
147
+ newps = copy (parameter_values (linfun. prob))
148
+ for (k, v) in p
149
+ if is_parameter (linfun, k)
150
+ v = fixpoint_sub (v, p)
151
+ setp (linfun, k)(newps, v)
136
152
end
137
- h_u = jacobian_wrt_vars (hp, p, input_idxs, chunk)
138
- (f_x = fg_xz[diff_idxs, diff_idxs],
139
- f_z = fg_xz[diff_idxs, alge_idxs],
140
- g_x = fg_xz[alge_idxs, diff_idxs],
141
- g_z = fg_xz[alge_idxs, alge_idxs],
142
- f_u = fg_u[diff_idxs, :],
143
- g_u = fg_u[alge_idxs, :],
144
- h_x = h_xz[:, diff_idxs],
145
- h_z = h_xz[:, alge_idxs],
146
- h_u = h_u)
147
153
end
154
+ p = newps
148
155
end
149
- return lin_fun, sys
156
+
157
+ fun = linfun. prob. f
158
+ if u != = nothing # Handle systems without unknowns
159
+ linfun. num_states == length (u) ||
160
+ error (" Number of unknown variables ($(linfun. num_states) ) does not match the number of input unknowns ($(length (u)) )" )
161
+ integ_cache = linfun. caches
162
+ integ = MockIntegrator {true} (u, p, t, integ_cache)
163
+ u, p, success = SciMLBase. get_initial_values (
164
+ linfun. prob, integ, fun, linfun. initializealg, Val (true );
165
+ linfun. initialize_kwargs... )
166
+ if ! success
167
+ error (" Initialization algorithm $(linfun. initializealg) failed with `u = $u ` and `p = $p `." )
168
+ end
169
+ uf = SciMLBase. UJacobianWrapper (fun, t, p)
170
+ fg_xz = ForwardDiff. jacobian (uf, u)
171
+ h_xz = ForwardDiff. jacobian (
172
+ let p = p, t = t, h = linfun. h
173
+ xz -> h (xz, p, t)
174
+ end , u)
175
+ pf = SciMLBase. ParamJacobianWrapper (fun, t, u)
176
+ fg_u = jacobian_wrt_vars (pf, p, linfun. input_idxs, linfun. chunk)
177
+ else
178
+ linfun. num_states == 0 ||
179
+ error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" )
180
+ fg_xz = zeros (0 , 0 )
181
+ h_xz = fg_u = zeros (0 , length (linfun. input_idxs))
182
+ end
183
+ hp = let u = u, t = t, h = linfun. h
184
+ _hp (p) = h (u, p, t)
185
+ _hp
186
+ end
187
+ h_u = jacobian_wrt_vars (hp, p, linfun. input_idxs, linfun. chunk)
188
+ (f_x = fg_xz[linfun. diff_idxs, linfun. diff_idxs],
189
+ f_z = fg_xz[linfun. diff_idxs, linfun. alge_idxs],
190
+ g_x = fg_xz[linfun. alge_idxs, linfun. diff_idxs],
191
+ g_z = fg_xz[linfun. alge_idxs, linfun. alge_idxs],
192
+ f_u = fg_u[linfun. diff_idxs, :],
193
+ g_u = fg_u[linfun. alge_idxs, :],
194
+ h_x = h_xz[:, linfun. diff_idxs],
195
+ h_z = h_xz[:, linfun. alge_idxs],
196
+ h_u = h_u)
150
197
end
151
198
152
199
"""
0 commit comments