@@ -25,6 +25,7 @@ The `simplified_sys` has undergone [`structural_simplify`](@ref) and had any occ
2525 - `simplify`: Apply simplification in tearing.
2626 - `initialize`: If true, a check is performed to ensure that the operating point is consistent (satisfies algebraic equations). If the op is not consistent, initialization is performed.
2727 - `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point.
28+ - `autodiff`: An `ADType` supported by DifferentiationInterface.jl to use for calculating the necessary jacobians. Defaults to using `AutoForwardDiff()`
2829 - `kwargs`: Are passed on to `find_solvables!`
2930
3031See also [`linearize`](@ref) which provides a higher-level interface.
@@ -39,6 +40,7 @@ function linearization_function(sys::AbstractSystem, inputs,
3940 p = DiffEqBase. NullParameters (),
4041 zero_dummy_der = false ,
4142 initialization_solver_alg = TrustRegion (),
43+ autodiff = AutoForwardDiff (),
4244 eval_expression = false , eval_module = @__MODULE__ ,
4345 warn_initialize_determined = true ,
4446 guesses = Dict (),
@@ -82,13 +84,104 @@ function linearization_function(sys::AbstractSystem, inputs,
8284 initialization_kwargs = (;
8385 abstol = initialization_abstol, reltol = initialization_reltol,
8486 nlsolve_alg = initialization_solver_alg)
87+
88+ p = parameter_values (prob)
89+ t0 = current_time (prob)
90+ inputvals = [p[idx] for idx in input_idxs]
91+
92+ hp_fun = let fun = h, setter = setp_oop (sys, input_idxs)
93+ function hpf (du, input, u, p, t)
94+ p = setter (p, input)
95+ fun (du, u, p, t)
96+ return du
97+ end
98+ end
99+ if u0 === nothing
100+ uf_jac = h_jac = pf_jac = nothing
101+ T = p isa MTKParameters ? eltype (p. tunable) : eltype (p)
102+ hp_jac = PreparedJacobian {true} (
103+ hp_fun, zeros (T, size (outputs)), autodiff, inputvals,
104+ DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
105+ else
106+ uf_fun = let fun = prob. f
107+ function uff (du, u, p, t)
108+ SciMLBase. UJacobianWrapper (fun, t, p)(du, u)
109+ end
110+ end
111+ uf_jac = PreparedJacobian {true} (
112+ uf_fun, similar (prob. u0), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
113+ # observed function is a `GeneratedFunctionWrapper` with iip component
114+ h_jac = PreparedJacobian {true} (h, similar (prob. u0, size (outputs)), autodiff,
115+ prob. u0, DI. Constant (p), DI. Constant (t0))
116+ pf_fun = let fun = prob. f, setter = setp_oop (sys, input_idxs)
117+ function pff (du, input, u, p, t)
118+ p = setter (p, input)
119+ SciMLBase. ParamJacobianWrapper (fun, t, u)(du, p)
120+ end
121+ end
122+ pf_jac = PreparedJacobian {true} (pf_fun, similar (prob. u0), autodiff, inputvals,
123+ DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
124+ hp_jac = PreparedJacobian {true} (
125+ hp_fun, similar (prob. u0, size (outputs)), autodiff, inputvals,
126+ DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
127+ end
128+
85129 lin_fun = LinearizationFunction (
86130 diff_idxs, alge_idxs, input_idxs, length (unknowns (sys)),
87- prob, h, u0 === nothing ? nothing : similar (u0),
88- ForwardDiff . Chunk (input_idxs) , initializealg, initialization_kwargs)
131+ prob, h, u0 === nothing ? nothing : similar (u0), uf_jac, h_jac, pf_jac,
132+ hp_jac , initializealg, initialization_kwargs)
89133 return lin_fun, sys
90134end
91135
136+ """
137+ $(TYPEDEF)
138+
139+ Callable struct which stores a function and its prepared `DI.jacobian`. Calling with the
140+ appropriate arguments for DI returns the jacobian.
141+
142+ # Fields
143+
144+ $(TYPEDFIELDS)
145+ """
146+ struct PreparedJacobian{iip, P, F, B, A}
147+ """
148+ The preparation object.
149+ """
150+ prep:: P
151+ """
152+ The function whose jacobian is calculated.
153+ """
154+ f:: F
155+ """
156+ Buffer for in-place functions.
157+ """
158+ buf:: B
159+ """
160+ ADType to use for differentiation.
161+ """
162+ autodiff:: A
163+ end
164+
165+ function PreparedJacobian {true} (f, buf, autodiff, args... )
166+ prep = DI. prepare_jacobian (f, buf, autodiff, args... )
167+ return PreparedJacobian {true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)} (
168+ prep, f, buf, autodiff)
169+ end
170+
171+ function PreparedJacobian {false} (f, autodiff, args... )
172+ prep = DI. prepare_jacobian (f, autodiff, args... )
173+ return PreparedJacobian {true, typeof(prep), typeof(f), Nothing, typeof(autodiff)} (
174+ prep, f, nothing )
175+ end
176+
177+ function (pj:: PreparedJacobian{true} )(args... )
178+ DI. jacobian (pj. f, pj. buf, pj. prep, pj. autodiff, args... )
179+ end
180+
181+ function (pj:: PreparedJacobian{false} )(args... )
182+ DI. jacobian (pj. f, pj. prep, pj. autodiff, args... )
183+ end
184+
92185"""
93186 $(TYPEDEF)
94187
@@ -100,7 +193,7 @@ $(TYPEDFIELDS)
100193"""
101194struct LinearizationFunction{
102195 DI <: AbstractVector{Int} , AI <: AbstractVector{Int} , II, P <: ODEProblem ,
103- H, C, Ch , IA <: SciMLBase.DAEInitializationAlgorithm , IK}
196+ H, C, J1, J2, J3, J4 , IA <: SciMLBase.DAEInitializationAlgorithm , IK}
104197 """
105198 The indexes of differential equations in the linearized system.
106199 """
@@ -130,11 +223,22 @@ struct LinearizationFunction{
130223 Any required cache buffers.
131224 """
132225 caches:: C
133- # TODO : Use DI?
134226 """
135- A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs.
227+ `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `u`
228+ """
229+ uf_jac:: J1
230+ """
231+ `PreparedJacobian` for calculating jacobian of `h` w.r.t. `u`
136232 """
137- chunk:: Ch
233+ h_jac:: J2
234+ """
235+ `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `p`
236+ """
237+ pf_jac:: J3
238+ """
239+ `PreparedJacobian` for calculating jacobian of `h` w.r.t. `p`
240+ """
241+ hp_jac:: J4
138242 """
139243 The initialization algorithm to use.
140244 """
@@ -188,25 +292,18 @@ function (linfun::LinearizationFunction)(u, p, t)
188292 if ! success
189293 error (" Initialization algorithm $(linfun. initializealg) failed with `u = $u ` and `p = $p `." )
190294 end
191- uf = SciMLBase. UJacobianWrapper (fun, t, p)
192- fg_xz = ForwardDiff. jacobian (uf, u)
193- h_xz = ForwardDiff. jacobian (
194- let p = p, t = t, h = linfun. h
195- xz -> h (xz, p, t)
196- end , u)
197- pf = SciMLBase. ParamJacobianWrapper (fun, t, u)
198- fg_u = jacobian_wrt_vars (pf, p, linfun. input_idxs, linfun. chunk)
295+ fg_xz = linfun. uf_jac (u, DI. Constant (p), DI. Constant (t))
296+ h_xz = linfun. h_jac (u, DI. Constant (p), DI. Constant (t))
297+ fg_u = linfun. pf_jac ([p[idx] for idx in linfun. input_idxs],
298+ DI. Constant (u), DI. Constant (p), DI. Constant (t))
199299 else
200300 linfun. num_states == 0 ||
201301 error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" )
202302 fg_xz = zeros (0 , 0 )
203303 h_xz = fg_u = zeros (0 , length (linfun. input_idxs))
204304 end
205- hp = let u = u, t = t, h = linfun. h
206- _hp (p) = h (u, p, t)
207- _hp
208- end
209- h_u = jacobian_wrt_vars (hp, p, linfun. input_idxs, linfun. chunk)
305+ h_u = linfun. hp_jac ([p[idx] for idx in linfun. input_idxs],
306+ DI. Constant (u), DI. Constant (p), DI. Constant (t))
210307 (f_x = fg_xz[linfun. diff_idxs, linfun. diff_idxs],
211308 f_z = fg_xz[linfun. diff_idxs, linfun. alge_idxs],
212309 g_x = fg_xz[linfun. alge_idxs, linfun. diff_idxs],
0 commit comments