@@ -161,6 +161,8 @@ where
161161 ///
162162 /// # Post-condition
163163 /// After the solver finishes, the internal state of the solver is at time `final_time`.
164+ /// If a root is found, the solver stops early. The internal state is moved to the root time,
165+ /// and the root time/value are returned as the last entry.
164166 #[ allow( clippy:: type_complexity) ]
165167 fn solve (
166168 & mut self ,
@@ -176,12 +178,40 @@ where
176178 // do the main loop
177179 write_out ( self , & mut ret_y, & mut ret_t, & mut tmp_nout) ;
178180 self . set_stop_time ( final_time) ?;
179- while self . step ( ) ? != OdeSolverStopReason :: TstopReached {
180- write_out ( self , & mut ret_y, & mut ret_t, & mut tmp_nout) ;
181+ loop {
182+ match self . step ( ) ? {
183+ OdeSolverStopReason :: InternalTimestep => {
184+ write_out ( self , & mut ret_y, & mut ret_t, & mut tmp_nout) ;
185+ }
186+ OdeSolverStopReason :: TstopReached => {
187+ write_out ( self , & mut ret_y, & mut ret_t, & mut tmp_nout) ;
188+ break ;
189+ }
190+ OdeSolverStopReason :: RootFound ( t_root) => {
191+ let nstates = self . problem ( ) . eqn . rhs ( ) . nstates ( ) ;
192+ let mut y_root =
193+ Eqn :: V :: zeros ( nstates, self . problem ( ) . context ( ) . clone ( ) ) ;
194+ self . interpolate_inplace ( t_root, & mut y_root) ?;
195+ let integrate_out = self . problem ( ) . integrate_out ;
196+ let mut g_root = None ;
197+ if integrate_out {
198+ let mut g = self . state ( ) . g . clone ( ) ;
199+ self . interpolate_out_inplace ( t_root, & mut g) ?;
200+ g_root = Some ( g) ;
201+ }
202+ {
203+ let state = self . state_mut ( ) ;
204+ state. y . copy_from ( & y_root) ;
205+ * state. t = t_root;
206+ if let Some ( g) = g_root. as_ref ( ) {
207+ state. g . copy_from ( g) ;
208+ }
209+ }
210+ write_out ( self , & mut ret_y, & mut ret_t, & mut tmp_nout) ;
211+ break ;
212+ }
213+ }
181214 }
182-
183- // store the final step
184- write_out ( self , & mut ret_y, & mut ret_t, & mut tmp_nout) ;
185215 let ntimes = ret_t. len ( ) ;
186216 ret_y. resize_cols ( ntimes) ;
187217 Ok ( ( ret_y, ret_t) )
@@ -202,6 +232,8 @@ where
202232 ///
203233 /// # Post-condition
204234 /// After the solver finishes, the internal state of the solver is at time `t_eval[t_eval.len()-1]`.
235+ /// If a root is found, the solver stops early. The internal state is moved to the root time,
236+ /// and the last column corresponds to the root time (which may not be in `t_eval`).
205237 fn solve_dense (
206238 & mut self ,
207239 t_eval : & [ Eqn :: T ] ,
@@ -214,14 +246,49 @@ where
214246
215247 // do loop
216248 self . set_stop_time ( t_eval[ t_eval. len ( ) - 1 ] ) ?;
217- let mut step_reason = OdeSolverStopReason :: InternalTimestep ;
218249 for ( i, t) in t_eval. iter ( ) . enumerate ( ) {
219250 while self . state ( ) . t < * t {
220- step_reason = self . step ( ) ?;
251+ match self . step ( ) ? {
252+ OdeSolverStopReason :: InternalTimestep => { }
253+ OdeSolverStopReason :: TstopReached => break ,
254+ OdeSolverStopReason :: RootFound ( t_root) => {
255+ self . interpolate_inplace ( t_root, & mut tmp_nstates) ?;
256+ let integrate_out = self . problem ( ) . integrate_out ;
257+ let mut g_root = None ;
258+ if integrate_out {
259+ let mut g = self . state ( ) . g . clone ( ) ;
260+ self . interpolate_out_inplace ( t_root, & mut g) ?;
261+ g_root = Some ( g) ;
262+ }
263+ {
264+ let state = self . state_mut ( ) ;
265+ state. y . copy_from ( & tmp_nstates) ;
266+ * state. t = t_root;
267+ if let Some ( g) = g_root. as_ref ( ) {
268+ state. g . copy_from ( g) ;
269+ }
270+ }
271+ {
272+ let mut y_out = ret. column_mut ( i) ;
273+ if integrate_out {
274+ y_out. copy_from ( g_root. as_ref ( ) . unwrap ( ) ) ;
275+ } else {
276+ match self . problem ( ) . eqn . out ( ) {
277+ Some ( out) => {
278+ out. call_inplace ( & tmp_nstates, t_root, & mut tmp_nout) ;
279+ y_out. copy_from ( & tmp_nout) ;
280+ }
281+ None => y_out. copy_from ( & tmp_nstates) ,
282+ }
283+ }
284+ }
285+ ret. resize_cols ( i + 1 ) ;
286+ return Ok ( ret) ;
287+ }
288+ }
221289 }
222290 dense_write_out ( self , & mut ret, t_eval, i, & mut tmp_nout, & mut tmp_nstates) ?;
223291 }
224- assert_eq ! ( step_reason, OdeSolverStopReason :: TstopReached ) ;
225292 Ok ( ret)
226293 }
227294
@@ -543,10 +610,12 @@ where
543610mod test {
544611 use crate :: {
545612 error:: { DiffsolError , OdeSolverError } ,
613+ matrix:: MatrixCommon ,
546614 matrix:: dense_nalgebra_serial:: NalgebraMat ,
547615 ode_equations:: test_models:: exponential_decay:: {
548616 exponential_decay_problem, exponential_decay_problem_adjoint,
549617 exponential_decay_problem_sens, exponential_decay_problem_sens_with_out,
618+ exponential_decay_problem_with_root,
550619 } ,
551620 scale, AdjointOdeSolverMethod , DenseMatrix , NalgebraLU , NalgebraVec , OdeEquations ,
552621 OdeSolverMethod , Op , SensitivitiesOdeSolverMethod , Vector , VectorView ,
@@ -569,6 +638,22 @@ mod test {
569638 }
570639 }
571640
641+ #[ test]
642+ fn test_solve_stops_on_root ( ) {
643+ let ( problem, _soln) = exponential_decay_problem_with_root :: < NalgebraMat < f64 > > ( false ) ;
644+ let mut s = problem. bdf :: < NalgebraLU < f64 > > ( ) . unwrap ( ) ;
645+
646+ let ( y, t) = s. solve ( 10.0 ) . unwrap ( ) ;
647+ let t_root = -0.6_f64 . ln ( ) / 0.1 ;
648+ let t_last = * t. last ( ) . unwrap ( ) ;
649+ assert ! ( ( t_last - t_root) . abs( ) < 1e-3 ) ;
650+ assert ! ( ( s. state( ) . t - t_root) . abs( ) < 1e-3 ) ;
651+
652+ let y_last = y. column ( y. ncols ( ) - 1 ) . into_owned ( ) ;
653+ let expected = NalgebraVec :: from_vec ( vec ! [ 0.6 , 0.6 ] , * problem. context ( ) ) ;
654+ y_last. assert_eq_norm ( & expected, & problem. atol , problem. rtol , 15.0 ) ;
655+ }
656+
572657 #[ test]
573658 fn test_solve_integrate_out ( ) {
574659 let ( problem, _soln) = exponential_decay_problem_adjoint :: < NalgebraMat < f64 > > ( true ) ;
@@ -604,6 +689,22 @@ mod test {
604689 }
605690 }
606691
692+ #[ test]
693+ fn test_dense_solve_stops_on_root ( ) {
694+ let ( problem, _soln) = exponential_decay_problem_with_root :: < NalgebraMat < f64 > > ( false ) ;
695+ let mut s = problem. bdf :: < NalgebraLU < f64 > > ( ) . unwrap ( ) ;
696+
697+ let t_eval = ( 0 ..=10 ) . map ( |i| i as f64 ) . collect :: < Vec < _ > > ( ) ;
698+ let y = s. solve_dense ( t_eval. as_slice ( ) ) . unwrap ( ) ;
699+ let t_root = -0.6_f64 . ln ( ) / 0.1 ;
700+ assert ! ( ( s. state( ) . t - t_root) . abs( ) < 1e-3 ) ;
701+ assert ! ( y. ncols( ) < t_eval. len( ) ) ;
702+
703+ let y_last = y. column ( y. ncols ( ) - 1 ) . into_owned ( ) ;
704+ let expected = NalgebraVec :: from_vec ( vec ! [ 0.6 , 0.6 ] , * problem. context ( ) ) ;
705+ y_last. assert_eq_norm ( & expected, & problem. atol , problem. rtol , 15.0 ) ;
706+ }
707+
607708 #[ test]
608709 fn test_dense_solve_integrate_out ( ) {
609710 let ( problem, soln) = exponential_decay_problem_adjoint :: < NalgebraMat < f64 > > ( true ) ;
0 commit comments