@@ -25,6 +25,7 @@ The `simplified_sys` has undergone [`structural_simplify`](@ref) and had any occ
2525  - `simplify`: Apply simplification in tearing. 
2626  - `initialize`: If true, a check is performed to ensure that the operating point is consistent (satisfies algebraic equations). If the op is not consistent, initialization is performed. 
2727  - `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point. 
28+   - `autodiff`: An `ADType` supported by DifferentiationInterface.jl to use for calculating the necessary jacobians. Defaults to using `AutoForwardDiff()` 
2829  - `kwargs`: Are passed on to `find_solvables!` 
2930
3031See also [`linearize`](@ref) which provides a higher-level interface. 
@@ -39,6 +40,7 @@ function linearization_function(sys::AbstractSystem, inputs,
3940        p =  DiffEqBase. NullParameters (),
4041        zero_dummy_der =  false ,
4142        initialization_solver_alg =  TrustRegion (),
43+         autodiff =  AutoForwardDiff (),
4244        eval_expression =  false , eval_module =  @__MODULE__ ,
4345        warn_initialize_determined =  true ,
4446        guesses =  Dict (),
@@ -82,13 +84,104 @@ function linearization_function(sys::AbstractSystem, inputs,
8284    initialization_kwargs =  (;
8385        abstol =  initialization_abstol, reltol =  initialization_reltol,
8486        nlsolve_alg =  initialization_solver_alg)
87+ 
88+     p =  parameter_values (prob)
89+     t0 =  current_time (prob)
90+     inputvals =  [p[idx] for  idx in  input_idxs]
91+ 
92+     hp_fun =  let  fun =  h, setter =  setp_oop (sys, input_idxs)
93+         function  hpf (du, input, u, p, t)
94+             p =  setter (p, input)
95+             fun (du, u, p, t)
96+             return  du
97+         end 
98+     end 
99+     if  u0 ===  nothing 
100+         uf_jac =  h_jac =  pf_jac =  nothing 
101+         T =  p isa  MTKParameters ?  eltype (p. tunable) :  eltype (p)
102+         hp_jac =  PreparedJacobian {true} (
103+             hp_fun, zeros (T, size (outputs)), autodiff, inputvals,
104+             DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
105+     else 
106+         uf_fun =  let  fun =  prob. f
107+             function  uff (du, u, p, t)
108+                 SciMLBase. UJacobianWrapper (fun, t, p)(du, u)
109+             end 
110+         end 
111+         uf_jac =  PreparedJacobian {true} (
112+             uf_fun, similar (prob. u0), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
113+         #  observed function is a `GeneratedFunctionWrapper` with iip component
114+         h_jac =  PreparedJacobian {true} (h, similar (prob. u0, size (outputs)), autodiff,
115+             prob. u0, DI. Constant (p), DI. Constant (t0))
116+         pf_fun =  let  fun =  prob. f, setter =  setp_oop (sys, input_idxs)
117+             function  pff (du, input, u, p, t)
118+                 p =  setter (p, input)
119+                 SciMLBase. ParamJacobianWrapper (fun, t, u)(du, p)
120+             end 
121+         end 
122+         pf_jac =  PreparedJacobian {true} (pf_fun, similar (prob. u0), autodiff, inputvals,
123+             DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
124+         hp_jac =  PreparedJacobian {true} (
125+             hp_fun, similar (prob. u0, size (outputs)), autodiff, inputvals,
126+             DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
127+     end 
128+ 
85129    lin_fun =  LinearizationFunction (
86130        diff_idxs, alge_idxs, input_idxs, length (unknowns (sys)),
87-         prob, h, u0 ===  nothing  ?  nothing  :  similar (u0),
88-         ForwardDiff . Chunk (input_idxs) , initializealg, initialization_kwargs)
131+         prob, h, u0 ===  nothing  ?  nothing  :  similar (u0), uf_jac, h_jac, pf_jac, 
132+         hp_jac , initializealg, initialization_kwargs)
89133    return  lin_fun, sys
90134end 
91135
136+ """ 
137+     $(TYPEDEF)  
138+ 
139+ Callable struct which stores a function and its prepared `DI.jacobian`. Calling with the 
140+ appropriate arguments for DI returns the jacobian. 
141+ 
142+ # Fields 
143+ 
144+ $(TYPEDFIELDS) 
145+ """ 
146+ struct  PreparedJacobian{iip, P, F, B, A}
147+     """ 
148+     The preparation object. 
149+     """  
150+     prep:: P 
151+     """ 
152+     The function whose jacobian is calculated. 
153+     """  
154+     f:: F 
155+     """ 
156+     Buffer for in-place functions. 
157+     """  
158+     buf:: B 
159+     """ 
160+     ADType to use for differentiation. 
161+     """  
162+     autodiff:: A 
163+ end 
164+ 
165+ function  PreparedJacobian {true} (f, buf, autodiff, args... )
166+     prep =  DI. prepare_jacobian (f, buf, autodiff, args... )
167+     return  PreparedJacobian {true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)} (
168+         prep, f, buf, autodiff)
169+ end 
170+ 
171+ function  PreparedJacobian {false} (f, autodiff, args... )
172+     prep =  DI. prepare_jacobian (f, autodiff, args... )
173+     return  PreparedJacobian {true, typeof(prep), typeof(f), Nothing, typeof(autodiff)} (
174+         prep, f, nothing )
175+ end 
176+ 
177+ function  (pj:: PreparedJacobian{true} )(args... )
178+     DI. jacobian (pj. f, pj. buf, pj. prep, pj. autodiff, args... )
179+ end 
180+ 
181+ function  (pj:: PreparedJacobian{false} )(args... )
182+     DI. jacobian (pj. f, pj. prep, pj. autodiff, args... )
183+ end 
184+ 
92185""" 
93186    $(TYPEDEF)  
94187
@@ -100,7 +193,7 @@ $(TYPEDFIELDS)
100193""" 
101194struct  LinearizationFunction{
102195    DI <:  AbstractVector{Int} , AI <:  AbstractVector{Int} , II, P <:  ODEProblem ,
103-     H, C, Ch , IA <:  SciMLBase.DAEInitializationAlgorithm , IK}
196+     H, C, J1, J2, J3, J4 , IA <:  SciMLBase.DAEInitializationAlgorithm , IK}
104197    """ 
105198    The indexes of differential equations in the linearized system. 
106199    """  
@@ -130,11 +223,22 @@ struct LinearizationFunction{
130223    Any required cache buffers. 
131224    """  
132225    caches:: C 
133-     #  TODO : Use DI?
134226    """ 
135-     A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs. 
227+     `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `u` 
228+     """  
229+     uf_jac:: J1 
230+     """ 
231+     `PreparedJacobian` for calculating jacobian of `h` w.r.t. `u` 
136232    """  
137-     chunk:: Ch 
233+     h_jac:: J2 
234+     """ 
235+     `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `p` 
236+     """  
237+     pf_jac:: J3 
238+     """ 
239+     `PreparedJacobian` for calculating jacobian of `h` w.r.t. `p` 
240+     """  
241+     hp_jac:: J4 
138242    """ 
139243    The initialization algorithm to use. 
140244    """  
@@ -188,25 +292,18 @@ function (linfun::LinearizationFunction)(u, p, t)
188292        if  ! success
189293            error (" Initialization algorithm $(linfun. initializealg)  failed with `u = $u ` and `p = $p `." 
190294        end 
191-         uf =  SciMLBase. UJacobianWrapper (fun, t, p)
192-         fg_xz =  ForwardDiff. jacobian (uf, u)
193-         h_xz =  ForwardDiff. jacobian (
194-             let  p =  p, t =  t, h =  linfun. h
195-                 xz ->  h (xz, p, t)
196-             end , u)
197-         pf =  SciMLBase. ParamJacobianWrapper (fun, t, u)
198-         fg_u =  jacobian_wrt_vars (pf, p, linfun. input_idxs, linfun. chunk)
295+         fg_xz =  linfun. uf_jac (u, DI. Constant (p), DI. Constant (t))
296+         h_xz =  linfun. h_jac (u, DI. Constant (p), DI. Constant (t))
297+         fg_u =  linfun. pf_jac ([p[idx] for  idx in  linfun. input_idxs],
298+             DI. Constant (u), DI. Constant (p), DI. Constant (t))
199299    else 
200300        linfun. num_states ==  0  || 
201301            error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" 
202302        fg_xz =  zeros (0 , 0 )
203303        h_xz =  fg_u =  zeros (0 , length (linfun. input_idxs))
204304    end 
205-     hp =  let  u =  u, t =  t, h =  linfun. h
206-         _hp (p) =  h (u, p, t)
207-         _hp
208-     end 
209-     h_u =  jacobian_wrt_vars (hp, p, linfun. input_idxs, linfun. chunk)
305+     h_u =  linfun. hp_jac ([p[idx] for  idx in  linfun. input_idxs],
306+         DI. Constant (u), DI. Constant (p), DI. Constant (t))
210307    (f_x =  fg_xz[linfun. diff_idxs, linfun. diff_idxs],
211308        f_z =  fg_xz[linfun. diff_idxs, linfun. alge_idxs],
212309        g_x =  fg_xz[linfun. alge_idxs, linfun. diff_idxs],
0 commit comments