@@ -25,6 +25,7 @@ The `simplified_sys` has undergone [`structural_simplify`](@ref) and had any occ
2525 - `simplify`: Apply simplification in tearing.
2626 - `initialize`: If true, a check is performed to ensure that the operating point is consistent (satisfies algebraic equations). If the op is not consistent, initialization is performed.
2727 - `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point.
28+ - `autodiff`: An `ADType` supported by DifferentiationInterface.jl to use for calculating the necessary jacobians.
2829 - `kwargs`: Are passed on to `find_solvables!`
2930
3031See also [`linearize`](@ref) which provides a higher-level interface.
@@ -39,6 +40,7 @@ function linearization_function(sys::AbstractSystem, inputs,
3940 p = DiffEqBase. NullParameters (),
4041 zero_dummy_der = false ,
4142 initialization_solver_alg = TrustRegion (),
43+ autodiff = AutoForwardDiff (),
4244 eval_expression = false , eval_module = @__MODULE__ ,
4345 warn_initialize_determined = true ,
4446 guesses = Dict (),
@@ -82,13 +84,89 @@ function linearization_function(sys::AbstractSystem, inputs,
8284 initialization_kwargs = (;
8385 abstol = initialization_abstol, reltol = initialization_reltol,
8486 nlsolve_alg = initialization_solver_alg)
87+
88+ p = parameter_values (prob)
89+ t0 = current_time (prob)
90+ inputvals = [p[idx] for idx in input_idxs]
91+
92+ uf_fun = let fun = prob. f
93+ function uff (du, u, p, t)
94+ SciMLBase. UJacobianWrapper (fun, t, p)(du, u)
95+ end
96+ end
97+ uf_jac = PreparedJacobian {true} (uf_fun, similar (prob. u0), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
98+ # observed function is a `GeneratedFunctionWrapper` with iip component
99+ h_jac = PreparedJacobian {true} (h, similar (prob. u0, size (outputs)), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
100+ pf_fun = let fun = prob. f, setter = setp_oop (sys, input_idxs)
101+ function pff (du, input, u, p, t)
102+ p = setter (p, input)
103+ SciMLBase. ParamJacobianWrapper (fun, t, u)(du, p)
104+ end
105+ end
106+ pf_jac = PreparedJacobian {true} (pf_fun, similar (prob. u0), autodiff, inputvals, DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
107+ hp_fun = let fun = h, setter = setp_oop (sys, input_idxs)
108+ function hpf (du, input, u, p, t)
109+ p = setter (p, input)
110+ fun (du, u, p, t)
111+ return du
112+ end
113+ end
114+ hp_jac = PreparedJacobian {true} (hp_fun, similar (prob. u0, size (outputs)), autodiff, inputvals, DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
115+
85116 lin_fun = LinearizationFunction (
86117 diff_idxs, alge_idxs, input_idxs, length (unknowns (sys)),
87- prob, h, u0 === nothing ? nothing : similar (u0),
88- ForwardDiff . Chunk (input_idxs) , initializealg, initialization_kwargs)
118+ prob, h, u0 === nothing ? nothing : similar (u0), uf_jac, h_jac, pf_jac,
119+ hp_jac , initializealg, initialization_kwargs)
89120 return lin_fun, sys
90121end
91122
123+ """
124+ $(TYPEDEF)
125+
126+ Callable struct which stores a function and its prepared `DI.jacobian`. Calling with the
127+ appropriate arguments for DI returns the jacobian.
128+
129+ # Fields
130+
131+ $(TYPEDFIELDS)
132+ """
133+ struct PreparedJacobian{iip, P, F, B, A}
134+ """
135+ The preparation object.
136+ """
137+ prep:: P
138+ """
139+ The function whose jacobian is calculated.
140+ """
141+ f:: F
142+ """
143+ Buffer for in-place functions.
144+ """
145+ buf:: B
146+ """
147+ ADType to use for differentiation.
148+ """
149+ autodiff:: A
150+ end
151+
152+ function PreparedJacobian {true} (f, buf, autodiff, args... )
153+ prep = DI. prepare_jacobian (f, buf, autodiff, args... )
154+ return PreparedJacobian {true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)} (prep, f, buf, autodiff)
155+ end
156+
157+ function PreparedJacobian {false} (f, autodiff, args... )
158+ prep = DI. prepare_jacobian (f, autodiff, args... )
159+ return PreparedJacobian {true, typeof(prep), typeof(f), Nothing, typeof(autodiff)} (prep, f, nothing )
160+ end
161+
162+ function (pj:: PreparedJacobian{true} )(args... )
163+ DI. jacobian (pj. f, pj. buf, pj. prep, pj. autodiff, args... )
164+ end
165+
166+ function (pj:: PreparedJacobian{false} )(args... )
167+ DI. jacobian (pj. f, pj. prep, pj. autodiff, args... )
168+ end
169+
92170"""
93171 $(TYPEDEF)
94172
@@ -100,7 +178,7 @@ $(TYPEDFIELDS)
100178"""
101179struct LinearizationFunction{
102180 DI <: AbstractVector{Int} , AI <: AbstractVector{Int} , II, P <: ODEProblem ,
103- H, C, Ch , IA <: SciMLBase.DAEInitializationAlgorithm , IK}
181+ H, C, J1, J2, J3, J4 , IA <: SciMLBase.DAEInitializationAlgorithm , IK}
104182 """
105183 The indexes of differential equations in the linearized system.
106184 """
@@ -130,11 +208,22 @@ struct LinearizationFunction{
130208 Any required cache buffers.
131209 """
132210 caches:: C
133- # TODO : Use DI?
134211 """
135- A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs.
212+ `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `u`
213+ """
214+ uf_jac:: J1
215+ """
216+ `PreparedJacobian` for calculating jacobian of `h` w.r.t. `u`
136217 """
137- chunk:: Ch
218+ h_jac:: J2
219+ """
220+ `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `p`
221+ """
222+ pf_jac:: J3
223+ """
224+ `PreparedJacobian` for calculating jacobian of `h` w.r.t. `p`
225+ """
226+ hp_jac:: J4
138227 """
139228 The initialization algorithm to use.
140229 """
@@ -188,25 +277,16 @@ function (linfun::LinearizationFunction)(u, p, t)
188277 if ! success
189278 error (" Initialization algorithm $(linfun. initializealg) failed with `u = $u ` and `p = $p `." )
190279 end
191- uf = SciMLBase. UJacobianWrapper (fun, t, p)
192- fg_xz = ForwardDiff. jacobian (uf, u)
193- h_xz = ForwardDiff. jacobian (
194- let p = p, t = t, h = linfun. h
195- xz -> h (xz, p, t)
196- end , u)
197- pf = SciMLBase. ParamJacobianWrapper (fun, t, u)
198- fg_u = jacobian_wrt_vars (pf, p, linfun. input_idxs, linfun. chunk)
280+ fg_xz = linfun. uf_jac (u, DI. Constant (p), DI. Constant (t))
281+ h_xz = linfun. h_jac (u, DI. Constant (p), DI. Constant (t))
282+ fg_u = linfun. pf_jac ([p[idx] for idx in linfun. input_idxs], DI. Constant (u), DI. Constant (p), DI. Constant (t))
199283 else
200284 linfun. num_states == 0 ||
201285 error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" )
202286 fg_xz = zeros (0 , 0 )
203287 h_xz = fg_u = zeros (0 , length (linfun. input_idxs))
204288 end
205- hp = let u = u, t = t, h = linfun. h
206- _hp (p) = h (u, p, t)
207- _hp
208- end
209- h_u = jacobian_wrt_vars (hp, p, linfun. input_idxs, linfun. chunk)
289+ h_u = linfun. hp_jac ([p[idx] for idx in linfun. input_idxs], DI. Constant (u), DI. Constant (p), DI. Constant (t))
210290 (f_x = fg_xz[linfun. diff_idxs, linfun. diff_idxs],
211291 f_z = fg_xz[linfun. diff_idxs, linfun. alge_idxs],
212292 g_x = fg_xz[linfun. alge_idxs, linfun. diff_idxs],
0 commit comments