Skip to content

Commit ccdcd0f

Browse files
committed
Allow pickling of solver
1 parent b13d83b commit ccdcd0f

File tree

2 files changed

+54
-23
lines changed

2 files changed

+54
-23
lines changed

sunode/solver.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -211,46 +211,74 @@ def current_stats(self) -> Dict[str, Any]:
211211

212212

213213
class Solver:
214-
def __init__(self, problem: Problem, *,
215-
abstol: float = 1e-10, reltol: float = 1e-10,
216-
sens_mode: Optional[str] = None, scaling_factors: Optional[np.ndarray] = None,
217-
constraints: Optional[np.ndarray] = None, solver='BDF', linear_solver="dense"):
218-
self._problem = problem
219-
self._user_data = problem.make_user_data()
220-
214+
def _init_sundials(self):
221215
n_states = self._problem.n_states
222216
n_params = self._problem.n_params
223217

224218
self._state_buffer = sunode.empty_vector(n_states)
225219
self._state_buffer.data[:] = 0
226220
self._jac = check(lib.SUNDenseMatrix(n_states, n_states))
227-
self._constraints = constraints
228221

229-
if solver == 'BDF':
230-
solver_kind = lib.CV_BDF
231-
elif solver == 'ADAMS':
232-
solver_kind = lib.CV_ADAMS
233-
else:
234-
assert False
235-
self._ode = check(lib.CVodeCreate(solver_kind))
236-
rhs = problem.make_sundials_rhs()
222+
self._ode = check(lib.CVodeCreate(self._solver_kind))
223+
rhs = self._problem.make_sundials_rhs()
237224
check(lib.CVodeInit(self._ode, rhs.cffi, 0., self._state_buffer.c_ptr))
238225

239-
self._set_tolerances(abstol, reltol)
226+
self._set_tolerances(self._abstol, self._reltol)
240227
if self._constraints is not None:
241-
assert constraints.shape == (n_states,)
242-
self._constraints_vec = sunode.from_numpy(constraints)
228+
assert self._constraints.shape == (n_states,)
229+
self._constraints_vec = sunode.from_numpy(self._constraints)
243230
check(lib.CVodeSetConstraints(self._ode, self._constraints_vec.c_ptr))
244231

245-
self._make_linsol(linear_solver)
232+
self._make_linsol(self._linear_solver_kind)
246233

247234
user_data_p = ffi.cast('void *', ffi.addressof(ffi.from_buffer(self._user_data.data)))
248235
check(lib.CVodeSetUserData(self._ode, user_data_p))
249236

250-
self._compute_sens = sens_mode is not None
237+
self._compute_sens = self._sens_mode is not None
251238
if self._compute_sens:
252239
sens_rhs = self._problem.make_sundials_sensitivity_rhs()
253-
self._init_sens(sens_rhs, sens_mode)
240+
self._init_sens(sens_rhs, self._sens_mode)
241+
242+
def __init__(self, problem: Problem, *,
243+
abstol: float = 1e-10, reltol: float = 1e-10,
244+
sens_mode: Optional[str] = None, scaling_factors: Optional[np.ndarray] = None,
245+
constraints: Optional[np.ndarray] = None, solver='BDF', linear_solver="dense"):
246+
self._problem = problem
247+
self._user_data = problem.make_user_data()
248+
self._constraints = constraints
249+
250+
self._abstol = abstol
251+
self._reltol = reltol
252+
253+
self._linear_solver_kind = linear_solver
254+
self._sens_mode = sens_mode
255+
256+
if solver == 'BDF':
257+
self._solver_kind = lib.CV_BDF
258+
elif solver == 'ADAMS':
259+
self._solver_kind = lib.CV_ADAMS
260+
else:
261+
assert False
262+
263+
self._state_names = [
264+
"_problem",
265+
"_user_data",
266+
"_constraints",
267+
"_abstol",
268+
"_reltol",
269+
"_linear_solver_kind",
270+
"_sens_mode",
271+
"_solver_kind",
272+
]
273+
274+
self._init_sundials()
275+
276+
def __getstate__(self):
277+
return {name: self.__dict__[name] for name in self._state_names}
278+
279+
def __setstate__(self, state):
280+
self.__dict__.update(state)
281+
self._init_sundials()
254282

255283
def _make_linsol(self, linear_solver) -> None:
256284
if linear_solver == "dense":

sunode/symode/problem.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
Shape = Tuple[int, ...]
1919

2020

21+
def _identity(x):
22+
return x
23+
2124
class SympyProblem(problem.Problem):
2225
def __init__(
2326
self,
@@ -43,7 +46,7 @@ def __init__(
4346
self._rhs_sympy_func = rhs_sympy
4447

4548
if simplify is None:
46-
simplify = lambda x: x
49+
simplify = _identity
4750
self._simplify = np.vectorize(simplify)
4851

4952
def check_dtype(dtype: np.dtype, path: Optional[str] = None) -> None:

0 commit comments

Comments
 (0)