@@ -221,8 +221,12 @@ def _init_sundials(self):
221221
222222 self ._ode = check (lib .CVodeCreate (self ._solver_kind ))
223223 rhs = self ._problem .make_sundials_rhs ()
224+ self ._rhs = rhs
224225 check (lib .CVodeInit (self ._ode , rhs .cffi , 0. , self ._state_buffer .c_ptr ))
225226
227+ user_data_p = ffi .cast ('void *' , ffi .addressof (ffi .from_buffer (self ._user_data .data )))
228+ check (lib .CVodeSetUserData (self ._ode , user_data_p ))
229+
226230 self ._set_tolerances (self ._abstol , self ._reltol )
227231 if self ._constraints is not None :
228232 assert self ._constraints .shape == (n_states ,)
@@ -231,9 +235,6 @@ def _init_sundials(self):
231235
232236 self ._make_linsol (self ._linear_solver_kind )
233237
234- user_data_p = ffi .cast ('void *' , ffi .addressof (ffi .from_buffer (self ._user_data .data )))
235- check (lib .CVodeSetUserData (self ._ode , user_data_p ))
236-
237238 self ._compute_sens = self ._sens_mode is not None
238239 if self ._compute_sens :
239240 sens_rhs = self ._problem .make_sundials_sensitivity_rhs ()
@@ -340,15 +341,20 @@ def _init_sens(self, sens_rhs, sens_mode, scaling_factors=None) -> None:
340341 def _set_tolerances (self , atol = None , rtol = None ):
341342 atol = np .array (atol )
342343 rtol = np .array (rtol )
344+ n_states = self ._problem .n_states
343345 if atol .ndim == 1 and rtol .ndim == 1 :
344346 atol = sunode .from_numpy (atol )
345347 rtol = sunode .from_numpy (rtol )
348+ assert atol .shape == (n_states ,)
349+ assert rtol .shape == (n_states ,)
346350 check (lib .CVodeVVtolerances (self ._ode , rtol .c_ptr , atol .c_ptr ))
347351 elif atol .ndim == 1 and rtol .ndim == 0 :
348352 atol = sunode .from_numpy (atol )
353+ assert atol .shape == (n_states ,)
349354 check (lib .CVodeSVtolerances (self ._ode , rtol , atol .c_ptr ))
350355 elif atol .ndim == 0 and rtol .ndim == 1 :
351356 rtol = sunode .from_numpy (rtol )
357+ assert rtol .shape == (n_states ,)
352358 check (lib .CVodeVStolerances (self ._ode , rtol .c_ptr , atol ))
353359 elif atol .ndim == 0 and rtol .ndim == 0 :
354360 check (lib .CVodeSStolerances (self ._ode , rtol , atol ))
@@ -416,6 +422,7 @@ def solve(self, t0, tvals, y0, y_out, *, sens0=None, sens_out=None, max_retries=
416422 TOO_MUCH_WORK = lib .CV_TOO_MUCH_WORK
417423
418424 n_params = self ._problem .n_params
425+ n_states = self ._problem .n_states
419426
420427 state_data = self ._state_buffer .data
421428 state_c_ptr = self ._state_buffer .c_ptr
@@ -428,6 +435,9 @@ def solve(self, t0, tvals, y0, y_out, *, sens0=None, sens_out=None, max_retries=
428435
429436 if y0 .dtype == self ._problem .state_dtype :
430437 y0 = y0 [None ].view (np .float64 )
438+
439+ if y0 .shape != (n_states ,):
440+ raise ValueError (f"y0 should have shape { (n_states ,)} but has shape { y0 .shape } ." )
431441 state_data [:] = y0
432442
433443 time_p = ffi .new ('double*' )
0 commit comments