@@ -214,7 +214,7 @@ class Solver:
214214 def __init__ (self , problem : Problem , * ,
215215 abstol : float = 1e-10 , reltol : float = 1e-10 ,
216216 sens_mode : Optional [str ] = None , scaling_factors : Optional [np .ndarray ] = None ,
217- constraints : Optional [np .ndarray ] = None , solver = 'BDF' ):
217+ constraints : Optional [np .ndarray ] = None , solver = 'BDF' , linear_solver = "dense" ):
218218 self ._problem = problem
219219 self ._user_data = problem .make_user_data ()
220220
@@ -242,7 +242,7 @@ def __init__(self, problem: Problem, *,
242242 self ._constraints_vec = sunode .from_numpy (constraints )
243243 check (lib .CVodeSetConstraints (self ._ode , self ._constraints_vec .c_ptr ))
244244
245- self ._make_linsol ()
245+ self ._make_linsol (linear_solver )
246246
247247 user_data_p = ffi .cast ('void *' , ffi .addressof (ffi .from_buffer (self ._user_data .data )))
248248 check (lib .CVodeSetUserData (self ._ode , user_data_p ))
@@ -252,12 +252,28 @@ def __init__(self, problem: Problem, *,
252252 sens_rhs = self ._problem .make_sundials_sensitivity_rhs ()
253253 self ._init_sens (sens_rhs , sens_mode )
254254
255- def _make_linsol (self ) -> None :
256- linsolver = check (lib .SUNLinSol_Dense (self ._state_buffer .c_ptr , self ._jac ))
257- check (lib .CVodeSetLinearSolver (self ._ode , linsolver , self ._jac ))
258-
259- self ._jac_func = self ._problem .make_sundials_jac_dense ()
260- check (lib .CVodeSetJacFn (self ._ode , self ._jac_func .cffi ))
255+ def _make_linsol (self , linear_solver ) -> None :
256+ if linear_solver == "dense" :
257+ linsolver = check (lib .SUNLinSol_Dense (self ._state_buffer .c_ptr , self ._jac ))
258+ check (lib .CVodeSetLinearSolver (self ._ode , linsolver , self ._jac ))
259+
260+ self ._jac_func = self ._problem .make_sundials_jac_dense ()
261+ check (lib .CVodeSetJacFn (self ._ode , self ._jac_func .cffi ))
262+ elif linear_solver == "dense_finitediff" :
263+ linsolver = check (lib .SUNLinSol_Dense (self ._state_buffer .c_ptr , self ._jac ))
264+ check (lib .CVodeSetLinearSolver (self ._ode , linsolver , self ._jac ))
265+ elif linear_solver == "spgmr_finitediff" :
266+ linsolver = check (lib .SUNLinSol_SPGMR (self ._state_buffer .c_ptr , lib .PREC_NONE , 5 ))
267+ check (lib .CVodeSetLinearSolver (self ._ode , linsolver , ffi .NULL ))
268+ check (lib .SUNLinSolInitialize_SPGMR (linsolver ))
269+ elif linear_solver == "spgmr" :
270+ linsolver = check (lib .SUNLinSol_SPGMR (self ._state_buffer .c_ptr , lib .PREC_NONE , 5 ))
271+ check (lib .CVodeSetLinearSolver (self ._ode , linsolver , ffi .NULL ))
272+ check (lib .SUNLinSolInitialize_SPGMR (linsolver ))
273+ jac_prod = self ._problem .make_sundials_jac_prod ()
274+ check (lib .CVodeSetJacTimes (self ._ode , ffi .NULL , jac_prod .cffi ))
275+ else :
276+ raise ValueError (f"Unknown linear solver: { linear_solver } " )
261277
262278 def _init_sens (self , sens_rhs , sens_mode , scaling_factors = None ) -> None :
263279 if sens_mode == 'simultaneous' :
0 commit comments