@@ -25,6 +25,7 @@ The `simplified_sys` has undergone [`structural_simplify`](@ref) and had any occ
2525  - `simplify`: Apply simplification in tearing. 
2626  - `initialize`: If true, a check is performed to ensure that the operating point is consistent (satisfies algebraic equations). If the op is not consistent, initialization is performed. 
2727  - `initialization_solver_alg`: A NonlinearSolve algorithm to use for solving for a feasible set of state and algebraic variables that satisfies the specified operating point. 
28+   - `autodiff`: An `ADType` supported by DifferentiationInterface.jl to use for calculating the necessary jacobians. 
2829  - `kwargs`: Are passed on to `find_solvables!` 
2930
3031See also [`linearize`](@ref) which provides a higher-level interface. 
@@ -39,6 +40,7 @@ function linearization_function(sys::AbstractSystem, inputs,
3940        p =  DiffEqBase. NullParameters (),
4041        zero_dummy_der =  false ,
4142        initialization_solver_alg =  TrustRegion (),
43+         autodiff =  AutoForwardDiff (),
4244        eval_expression =  false , eval_module =  @__MODULE__ ,
4345        warn_initialize_determined =  true ,
4446        guesses =  Dict (),
@@ -82,13 +84,89 @@ function linearization_function(sys::AbstractSystem, inputs,
8284    initialization_kwargs =  (;
8385        abstol =  initialization_abstol, reltol =  initialization_reltol,
8486        nlsolve_alg =  initialization_solver_alg)
87+ 
88+     p =  parameter_values (prob)
89+     t0 =  current_time (prob)
90+     inputvals =  [p[idx] for  idx in  input_idxs]
91+ 
92+     uf_fun =  let  fun =  prob. f
93+         function  uff (du, u, p, t)
94+             SciMLBase. UJacobianWrapper (fun, t, p)(du, u)
95+         end 
96+     end 
97+     uf_jac =  PreparedJacobian {true} (uf_fun, similar (prob. u0), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
98+     #  observed function is a `GeneratedFunctionWrapper` with iip component
99+     h_jac =  PreparedJacobian {true} (h, similar (prob. u0, size (outputs)), autodiff, prob. u0, DI. Constant (p), DI. Constant (t0))
100+     pf_fun =  let  fun =  prob. f, setter =  setp_oop (sys, input_idxs)
101+         function  pff (du, input, u, p, t)
102+             p =  setter (p, input)
103+             SciMLBase. ParamJacobianWrapper (fun, t, u)(du, p)
104+         end 
105+     end 
106+     pf_jac =  PreparedJacobian {true} (pf_fun, similar (prob. u0), autodiff, inputvals, DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
107+     hp_fun =  let  fun =  h, setter =  setp_oop (sys, input_idxs)
108+         function  hpf (du, input, u, p, t)
109+             p =  setter (p, input)
110+             fun (du, u, p, t)
111+             return  du
112+         end 
113+     end 
114+     hp_jac =  PreparedJacobian {true} (hp_fun, similar (prob. u0, size (outputs)), autodiff, inputvals, DI. Constant (prob. u0), DI. Constant (p), DI. Constant (t0))
115+     
85116    lin_fun =  LinearizationFunction (
86117        diff_idxs, alge_idxs, input_idxs, length (unknowns (sys)),
87-         prob, h, u0 ===  nothing  ?  nothing  :  similar (u0),
88-         ForwardDiff . Chunk (input_idxs) , initializealg, initialization_kwargs)
118+         prob, h, u0 ===  nothing  ?  nothing  :  similar (u0), uf_jac, h_jac, pf_jac, 
119+         hp_jac , initializealg, initialization_kwargs)
89120    return  lin_fun, sys
90121end 
91122
123+ """ 
124+     $(TYPEDEF)  
125+ 
126+ Callable struct which stores a function and its prepared `DI.jacobian`. Calling with the 
127+ appropriate arguments for DI returns the jacobian. 
128+ 
129+ # Fields 
130+ 
131+ $(TYPEDFIELDS) 
132+ """ 
133+ struct  PreparedJacobian{iip, P, F, B, A}
134+     """ 
135+     The preparation object. 
136+     """  
137+     prep:: P 
138+     """ 
139+     The function whose jacobian is calculated. 
140+     """  
141+     f:: F 
142+     """ 
143+     Buffer for in-place functions. 
144+     """  
145+     buf:: B 
146+     """ 
147+     ADType to use for differentiation. 
148+     """  
149+     autodiff:: A 
150+ end 
151+ 
152+ function  PreparedJacobian {true} (f, buf, autodiff, args... )
153+     prep =  DI. prepare_jacobian (f, buf, autodiff, args... )
154+     return  PreparedJacobian {true, typeof(prep), typeof(f), typeof(buf), typeof(autodiff)} (prep, f, buf, autodiff)
155+ end 
156+ 
157+ function  PreparedJacobian {false} (f, autodiff, args... )
158+     prep =  DI. prepare_jacobian (f, autodiff, args... )
159+     return  PreparedJacobian {true, typeof(prep), typeof(f), Nothing, typeof(autodiff)} (prep, f, nothing )
160+ end 
161+ 
162+ function  (pj:: PreparedJacobian{true} )(args... )
163+     DI. jacobian (pj. f, pj. buf, pj. prep, pj. autodiff, args... )
164+ end 
165+ 
166+ function  (pj:: PreparedJacobian{false} )(args... )
167+     DI. jacobian (pj. f, pj. prep, pj. autodiff, args... )
168+ end 
169+ 
92170""" 
93171    $(TYPEDEF)  
94172
@@ -100,7 +178,7 @@ $(TYPEDFIELDS)
100178""" 
101179struct  LinearizationFunction{
102180    DI <:  AbstractVector{Int} , AI <:  AbstractVector{Int} , II, P <:  ODEProblem ,
103-     H, C, Ch , IA <:  SciMLBase.DAEInitializationAlgorithm , IK}
181+     H, C, J1, J2, J3, J4 , IA <:  SciMLBase.DAEInitializationAlgorithm , IK}
104182    """ 
105183    The indexes of differential equations in the linearized system. 
106184    """  
@@ -130,11 +208,22 @@ struct LinearizationFunction{
130208    Any required cache buffers. 
131209    """  
132210    caches:: C 
133-     #  TODO : Use DI?
134211    """ 
135-     A `ForwardDiff.Chunk` for taking the jacobian with respect to the inputs. 
212+     `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `u` 
213+     """  
214+     uf_jac:: J1 
215+     """ 
216+     `PreparedJacobian` for calculating jacobian of `h` w.r.t. `u` 
136217    """  
137-     chunk:: Ch 
218+     h_jac:: J2 
219+     """ 
220+     `PreparedJacobian` for calculating jacobian of `prob.f` w.r.t. `p` 
221+     """  
222+     pf_jac:: J3 
223+     """ 
224+     `PreparedJacobian` for calculating jacobian of `h` w.r.t. `p` 
225+     """  
226+     hp_jac:: J4 
138227    """ 
139228    The initialization algorithm to use. 
140229    """  
@@ -188,25 +277,16 @@ function (linfun::LinearizationFunction)(u, p, t)
188277        if  ! success
189278            error (" Initialization algorithm $(linfun. initializealg)  failed with `u = $u ` and `p = $p `." 
190279        end 
191-         uf =  SciMLBase. UJacobianWrapper (fun, t, p)
192-         fg_xz =  ForwardDiff. jacobian (uf, u)
193-         h_xz =  ForwardDiff. jacobian (
194-             let  p =  p, t =  t, h =  linfun. h
195-                 xz ->  h (xz, p, t)
196-             end , u)
197-         pf =  SciMLBase. ParamJacobianWrapper (fun, t, u)
198-         fg_u =  jacobian_wrt_vars (pf, p, linfun. input_idxs, linfun. chunk)
280+         fg_xz =  linfun. uf_jac (u, DI. Constant (p), DI. Constant (t))
281+         h_xz =  linfun. h_jac (u, DI. Constant (p), DI. Constant (t))
282+         fg_u =  linfun. pf_jac ([p[idx] for  idx in  linfun. input_idxs], DI. Constant (u), DI. Constant (p), DI. Constant (t))
199283    else 
200284        linfun. num_states ==  0  || 
201285            error (" Number of unknown variables (0) does not match the number of input unknowns ($(length (u)) )" 
202286        fg_xz =  zeros (0 , 0 )
203287        h_xz =  fg_u =  zeros (0 , length (linfun. input_idxs))
204288    end 
205-     hp =  let  u =  u, t =  t, h =  linfun. h
206-         _hp (p) =  h (u, p, t)
207-         _hp
208-     end 
209-     h_u =  jacobian_wrt_vars (hp, p, linfun. input_idxs, linfun. chunk)
289+     h_u =  linfun. hp_jac ([p[idx] for  idx in  linfun. input_idxs], DI. Constant (u), DI. Constant (p), DI. Constant (t))
210290    (f_x =  fg_xz[linfun. diff_idxs, linfun. diff_idxs],
211291        f_z =  fg_xz[linfun. diff_idxs, linfun. alge_idxs],
212292        g_x =  fg_xz[linfun. alge_idxs, linfun. diff_idxs],
0 commit comments