@@ -217,7 +217,6 @@ def _init_sundials(self):
217217
218218 self ._state_buffer = sunode .empty_vector (n_states )
219219 self ._state_buffer .data [:] = 0
220- self ._jac = check (lib .SUNDenseMatrix (n_states , n_states ))
221220
222221 self ._ode = check (lib .CVodeCreate (self ._solver_kind ))
223222 rhs = self ._problem .make_sundials_rhs ()
@@ -233,17 +232,26 @@ def _init_sundials(self):
233232 self ._constraints_vec = sunode .from_numpy (self ._constraints )
234233 check (lib .CVodeSetConstraints (self ._ode , self ._constraints_vec .c_ptr ))
235234
236- self ._make_linsol (self ._linear_solver_kind )
235+ self ._make_linsol (self ._linear_solver_kind , ** self . _linear_solver_kwargs )
237236
238237 self ._compute_sens = self ._sens_mode is not None
239238 if self ._compute_sens :
240239 sens_rhs = self ._problem .make_sundials_sensitivity_rhs ()
241240 self ._init_sens (sens_rhs , self ._sens_mode )
242241
243- def __init__ (self , problem : Problem , * ,
244- abstol : float = 1e-10 , reltol : float = 1e-10 ,
245- sens_mode : Optional [str ] = None , scaling_factors : Optional [np .ndarray ] = None ,
246- constraints : Optional [np .ndarray ] = None , solver = 'BDF' , linear_solver = "dense" ):
242+ def __init__ (
243+ self ,
244+ problem : Problem ,
245+ * ,
246+ abstol : float = 1e-10 ,
247+ reltol : float = 1e-10 ,
248+ sens_mode : Optional [str ] = None ,
249+ scaling_factors : Optional [np .ndarray ] = None ,
250+ constraints : Optional [np .ndarray ] = None ,
251+ solver = 'BDF' ,
252+ linear_solver = "dense" ,
253+ linear_solver_kwargs = None ,
254+ ):
247255 self ._problem = problem
248256 self ._user_data = problem .make_user_data ()
249257 self ._constraints = constraints
@@ -252,6 +260,7 @@ def __init__(self, problem: Problem, *,
252260 self ._reltol = reltol
253261
254262 self ._linear_solver_kind = linear_solver
263+ self ._linear_solver_kwargs = linear_solver_kwargs
255264 self ._sens_mode = sens_mode
256265
257266 if solver == 'BDF' :
@@ -268,6 +277,7 @@ def __init__(self, problem: Problem, *,
268277 "_abstol" ,
269278 "_reltol" ,
270279 "_linear_solver_kind" ,
280+ "_linear_solver_kwargs" ,
271281 "_sens_mode" ,
272282 "_solver_kind" ,
273283 ]
@@ -281,14 +291,17 @@ def __setstate__(self, state):
281291 self .__dict__ .update (state )
282292 self ._init_sundials ()
283293
284- def _make_linsol (self , linear_solver ) -> None :
294+ def _make_linsol (self , linear_solver , ** kwargs ) -> None :
295+ n_states = self ._problem .n_states
285296 if linear_solver == "dense" :
297+ self ._jac = check (lib .SUNDenseMatrix (n_states , n_states ))
286298 linsolver = check (lib .SUNLinSol_Dense (self ._state_buffer .c_ptr , self ._jac ))
287299 check (lib .CVodeSetLinearSolver (self ._ode , linsolver , self ._jac ))
288300
289301 self ._jac_func = self ._problem .make_sundials_jac_dense ()
290302 check (lib .CVodeSetJacFn (self ._ode , self ._jac_func .cffi ))
291303 elif linear_solver == "dense_finitediff" :
304+ self ._jac = check (lib .SUNDenseMatrix (n_states , n_states ))
292305 linsolver = check (lib .SUNLinSol_Dense (self ._state_buffer .c_ptr , self ._jac ))
293306 check (lib .CVodeSetLinearSolver (self ._ode , linsolver , self ._jac ))
294307 elif linear_solver == "spgmr_finitediff" :
@@ -301,6 +314,14 @@ def _make_linsol(self, linear_solver) -> None:
301314 check (lib .SUNLinSolInitialize_SPGMR (linsolver ))
302315 jac_prod = self ._problem .make_sundials_jac_prod ()
303316 check (lib .CVodeSetJacTimes (self ._ode , ffi .NULL , jac_prod .cffi ))
317+ elif linear_solver == "band" :
318+ upper_bandwidth = kwargs .get ("upper_bandwidth" , None )
319+ lower_bandwidth = kwargs .get ("lower_bandwidth" , None )
320+ if upper_bandwidth is None or lower_bandwidth is None :
321+ raise ValueError ("Specify 'lower_bandwidth' and 'upper_bandwidth' arguments for banded solver." )
322+ self ._jac = check (lib .SUNBandMatrix (n_states , upper_bandwidth , lower_bandwidth ))
323+ linsolver = check (lib .SUNLinSol_Band (self ._state_buffer .c_ptr , self ._jac ))
324+ check (lib .CVodeSetLinearSolver (self ._ode , linsolver , self ._jac ))
304325 else :
305326 raise ValueError (f"Unknown linear solver: { linear_solver } " )
306327
0 commit comments