@@ -211,46 +211,74 @@ def current_stats(self) -> Dict[str, Any]:
211211
212212
213213class Solver :
214- def __init__ (self , problem : Problem , * ,
215- abstol : float = 1e-10 , reltol : float = 1e-10 ,
216- sens_mode : Optional [str ] = None , scaling_factors : Optional [np .ndarray ] = None ,
217- constraints : Optional [np .ndarray ] = None , solver = 'BDF' , linear_solver = "dense" ):
218- self ._problem = problem
219- self ._user_data = problem .make_user_data ()
220-
214+ def _init_sundials (self ):
221215 n_states = self ._problem .n_states
222216 n_params = self ._problem .n_params
223217
224218 self ._state_buffer = sunode .empty_vector (n_states )
225219 self ._state_buffer .data [:] = 0
226220 self ._jac = check (lib .SUNDenseMatrix (n_states , n_states ))
227- self ._constraints = constraints
228221
229- if solver == 'BDF' :
230- solver_kind = lib .CV_BDF
231- elif solver == 'ADAMS' :
232- solver_kind = lib .CV_ADAMS
233- else :
234- assert False
235- self ._ode = check (lib .CVodeCreate (solver_kind ))
236- rhs = problem .make_sundials_rhs ()
222+ self ._ode = check (lib .CVodeCreate (self ._solver_kind ))
223+ rhs = self ._problem .make_sundials_rhs ()
237224 check (lib .CVodeInit (self ._ode , rhs .cffi , 0. , self ._state_buffer .c_ptr ))
238225
239- self ._set_tolerances (abstol , reltol )
226+ self ._set_tolerances (self . _abstol , self . _reltol )
240227 if self ._constraints is not None :
241- assert constraints .shape == (n_states ,)
242- self ._constraints_vec = sunode .from_numpy (constraints )
228+ assert self . _constraints .shape == (n_states ,)
229+ self ._constraints_vec = sunode .from_numpy (self . _constraints )
243230 check (lib .CVodeSetConstraints (self ._ode , self ._constraints_vec .c_ptr ))
244231
245- self ._make_linsol (linear_solver )
232+ self ._make_linsol (self . _linear_solver_kind )
246233
247234 user_data_p = ffi .cast ('void *' , ffi .addressof (ffi .from_buffer (self ._user_data .data )))
248235 check (lib .CVodeSetUserData (self ._ode , user_data_p ))
249236
250- self ._compute_sens = sens_mode is not None
237+ self ._compute_sens = self . _sens_mode is not None
251238 if self ._compute_sens :
252239 sens_rhs = self ._problem .make_sundials_sensitivity_rhs ()
253- self ._init_sens (sens_rhs , sens_mode )
240+ self ._init_sens (sens_rhs , self ._sens_mode )
241+
242+ def __init__ (self , problem : Problem , * ,
243+ abstol : float = 1e-10 , reltol : float = 1e-10 ,
244+ sens_mode : Optional [str ] = None , scaling_factors : Optional [np .ndarray ] = None ,
245+ constraints : Optional [np .ndarray ] = None , solver = 'BDF' , linear_solver = "dense" ):
246+ self ._problem = problem
247+ self ._user_data = problem .make_user_data ()
248+ self ._constraints = constraints
249+
250+ self ._abstol = abstol
251+ self ._reltol = reltol
252+
253+ self ._linear_solver_kind = linear_solver
254+ self ._sens_mode = sens_mode
255+
256+ if solver == 'BDF' :
257+ self ._solver_kind = lib .CV_BDF
258+ elif solver == 'ADAMS' :
259+ self ._solver_kind = lib .CV_ADAMS
260+ else :
261+ assert False
262+
263+ self ._state_names = [
264+ "_problem" ,
265+ "_user_data" ,
266+ "_constraints" ,
267+ "_abstol" ,
268+ "_reltol" ,
269+ "_linear_solver_kind" ,
270+ "_sens_mode" ,
271+ "_solver_kind" ,
272+ ]
273+
274+ self ._init_sundials ()
275+
276+ def __getstate__ (self ):
277+ return {name : self .__dict__ [name ] for name in self ._state_names }
278+
279+ def __setstate__ (self , state ):
280+ self .__dict__ .update (state )
281+ self ._init_sundials ()
254282
255283 def _make_linsol (self , linear_solver ) -> None :
256284 if linear_solver == "dense" :
0 commit comments