Skip to content

Commit 270df03

Browse files
astoerikoaseyboldt
authored andcommitted
Add option to use linear solver for banded matrix
1 parent f1f7f79 commit 270df03

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

sunode/solver.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ def _init_sundials(self):
217217

218218
self._state_buffer = sunode.empty_vector(n_states)
219219
self._state_buffer.data[:] = 0
220-
self._jac = check(lib.SUNDenseMatrix(n_states, n_states))
221220

222221
self._ode = check(lib.CVodeCreate(self._solver_kind))
223222
rhs = self._problem.make_sundials_rhs()
@@ -233,17 +232,26 @@ def _init_sundials(self):
233232
self._constraints_vec = sunode.from_numpy(self._constraints)
234233
check(lib.CVodeSetConstraints(self._ode, self._constraints_vec.c_ptr))
235234

236-
self._make_linsol(self._linear_solver_kind)
235+
self._make_linsol(self._linear_solver_kind, **self._linear_solver_kwargs)
237236

238237
self._compute_sens = self._sens_mode is not None
239238
if self._compute_sens:
240239
sens_rhs = self._problem.make_sundials_sensitivity_rhs()
241240
self._init_sens(sens_rhs, self._sens_mode)
242241

243-
def __init__(self, problem: Problem, *,
244-
abstol: float = 1e-10, reltol: float = 1e-10,
245-
sens_mode: Optional[str] = None, scaling_factors: Optional[np.ndarray] = None,
246-
constraints: Optional[np.ndarray] = None, solver='BDF', linear_solver="dense"):
242+
def __init__(
243+
self,
244+
problem: Problem,
245+
*,
246+
abstol: float = 1e-10,
247+
reltol: float = 1e-10,
248+
sens_mode: Optional[str] = None,
249+
scaling_factors: Optional[np.ndarray] = None,
250+
constraints: Optional[np.ndarray] = None,
251+
solver='BDF',
252+
linear_solver="dense",
253+
linear_solver_kwargs=None,
254+
):
247255
self._problem = problem
248256
self._user_data = problem.make_user_data()
249257
self._constraints = constraints
@@ -252,6 +260,7 @@ def __init__(self, problem: Problem, *,
252260
self._reltol = reltol
253261

254262
self._linear_solver_kind = linear_solver
263+
self._linear_solver_kwargs = linear_solver_kwargs
255264
self._sens_mode = sens_mode
256265

257266
if solver == 'BDF':
@@ -268,6 +277,7 @@ def __init__(self, problem: Problem, *,
268277
"_abstol",
269278
"_reltol",
270279
"_linear_solver_kind",
280+
"_linear_solver_kwargs",
271281
"_sens_mode",
272282
"_solver_kind",
273283
]
@@ -281,14 +291,17 @@ def __setstate__(self, state):
281291
self.__dict__.update(state)
282292
self._init_sundials()
283293

284-
def _make_linsol(self, linear_solver) -> None:
294+
def _make_linsol(self, linear_solver, **kwargs) -> None:
295+
n_states = self._problem.n_states
285296
if linear_solver == "dense":
297+
self._jac = check(lib.SUNDenseMatrix(n_states, n_states))
286298
linsolver = check(lib.SUNLinSol_Dense(self._state_buffer.c_ptr, self._jac))
287299
check(lib.CVodeSetLinearSolver(self._ode, linsolver, self._jac))
288300

289301
self._jac_func = self._problem.make_sundials_jac_dense()
290302
check(lib.CVodeSetJacFn(self._ode, self._jac_func.cffi))
291303
elif linear_solver == "dense_finitediff":
304+
self._jac = check(lib.SUNDenseMatrix(n_states, n_states))
292305
linsolver = check(lib.SUNLinSol_Dense(self._state_buffer.c_ptr, self._jac))
293306
check(lib.CVodeSetLinearSolver(self._ode, linsolver, self._jac))
294307
elif linear_solver == "spgmr_finitediff":
@@ -301,6 +314,14 @@ def _make_linsol(self, linear_solver) -> None:
301314
check(lib.SUNLinSolInitialize_SPGMR(linsolver))
302315
jac_prod = self._problem.make_sundials_jac_prod()
303316
check(lib.CVodeSetJacTimes(self._ode, ffi.NULL, jac_prod.cffi))
317+
elif linear_solver == "band":
318+
upper_bandwidth = kwargs.get("upper_bandwidth", None)
319+
lower_bandwidth = kwargs.get("lower_bandwidth", None)
320+
if upper_bandwidth is None or lower_bandwidth is None:
321+
raise ValueError("Specify 'lower_bandwidth' and 'upper_bandwidth' arguments for banded solver.")
322+
self._jac = check(lib.SUNBandMatrix(n_states, upper_bandwidth, lower_bandwidth))
323+
linsolver = check(lib.SUNLinSol_Band(self._state_buffer.c_ptr, self._jac))
324+
check(lib.CVodeSetLinearSolver(self._ode, linsolver, self._jac))
304325
else:
305326
raise ValueError(f"Unknown linear solver: {linear_solver}")
306327

0 commit comments

Comments
 (0)