@@ -5,9 +5,10 @@ use crate::{
55 ode_solver_error,
66 scalar:: Scalar ,
77 AugmentedOdeEquations , Checkpointing , Context , DefaultDenseMatrix , DenseMatrix ,
8- HermiteInterpolator , NonLinearOp , OdeEquations , OdeSolverConfig , OdeSolverProblem ,
9- OdeSolverState , Op , StateRef , StateRefMut , Vector , VectorViewMut ,
8+ HermiteInterpolator , MatrixCommon , NonLinearOp , OdeEquations , OdeSolverConfig ,
9+ OdeSolverProblem , OdeSolverState , Op , StateRef , StateRefMut , Vector , VectorViewMut ,
1010} ;
11+ use nalgebra:: ComplexField ;
1112
1213#[ derive( Debug , PartialEq ) ]
1314pub enum OdeSolverStopReason < T : Scalar > {
@@ -120,27 +121,20 @@ where
120121 Self : Sized ,
121122 {
122123 let mut ret_t = Vec :: new ( ) ;
123- let mut ret_y = Vec :: new ( ) ;
124+ let ( mut ret_y, mut tmp_nout ) = allocate_return ( self ) ? ;
124125
125126 // do the main loop
126- write_out ( self , & mut ret_y, & mut ret_t) ;
127+ write_out ( self , & mut ret_y, & mut ret_t, final_time , & mut tmp_nout ) ;
127128 self . set_stop_time ( final_time) ?;
128129 while self . step ( ) ? != OdeSolverStopReason :: TstopReached {
129- write_out ( self , & mut ret_y, & mut ret_t) ;
130+ write_out ( self , & mut ret_y, & mut ret_t, final_time , & mut tmp_nout ) ;
130131 }
131132
132133 // store the final step
133- write_out ( self , & mut ret_y, & mut ret_t) ;
134+ write_out ( self , & mut ret_y, & mut ret_t, final_time , & mut tmp_nout ) ;
134135 let ntimes = ret_t. len ( ) ;
135- let nrows = ret_y[ 0 ] . len ( ) ;
136- let mut ret_y_matrix = self
137- . problem ( )
138- . context ( )
139- . dense_mat_zeros :: < Eqn :: V > ( nrows, ntimes) ;
140- for ( i, y) in ret_y. iter ( ) . enumerate ( ) {
141- ret_y_matrix. column_mut ( i) . copy_from ( y) ;
142- }
143- Ok ( ( ret_y_matrix, ret_t) )
136+ ret_y. resize_cols ( ntimes) ;
137+ Ok ( ( ret_y, ret_t) )
144138 }
145139
146140 /// Using the provided state, solve the problem up to time `t_eval[t_eval.len()-1]`
@@ -154,7 +148,7 @@ where
154148 Eqn :: V : DefaultDenseMatrix ,
155149 Self : Sized ,
156150 {
157- let mut ret = dense_allocate_return ( self , t_eval) ?;
151+ let ( mut ret, mut tmp_nout ) = dense_allocate_return ( self , t_eval) ?;
158152
159153 // do loop
160154 self . set_stop_time ( t_eval[ t_eval. len ( ) - 1 ] ) ?;
@@ -163,7 +157,7 @@ where
163157 while self . state ( ) . t < * t {
164158 step_reason = self . step ( ) ?;
165159 }
166- dense_write_out ( self , & mut ret, t_eval, i) ?;
160+ dense_write_out ( self , & mut ret, t_eval, i, & mut tmp_nout ) ?;
167161 }
168162 assert_eq ! ( step_reason, OdeSolverStopReason :: TstopReached ) ;
169163 Ok ( ret)
@@ -187,7 +181,7 @@ where
187181 Self : Sized ,
188182 {
189183 let mut ret_t = Vec :: new ( ) ;
190- let mut ret_y = Vec :: new ( ) ;
184+ let ( mut ret_y, mut tmp_nout ) = allocate_return ( self ) ? ;
191185 let max_steps_between_checkpoints = max_steps_between_checkpoints. unwrap_or ( 500 ) ;
192186
193187 // allocate checkpoint info
@@ -199,10 +193,10 @@ where
199193 let mut ydots = vec ! [ self . state( ) . dy. clone( ) ] ;
200194
201195 // do the main loop, saving checkpoints
202- write_out ( self , & mut ret_y, & mut ret_t) ;
196+ write_out ( self , & mut ret_y, & mut ret_t, final_time , & mut tmp_nout ) ;
203197 self . set_stop_time ( final_time) ?;
204198 while self . step ( ) ? != OdeSolverStopReason :: TstopReached {
205- write_out ( self , & mut ret_y, & mut ret_t) ;
199+ write_out ( self , & mut ret_y, & mut ret_t, final_time , & mut tmp_nout ) ;
206200 ts. push ( self . state ( ) . t ) ;
207201 ys. push ( self . state ( ) . y . clone ( ) ) ;
208202 ydots. push ( self . state ( ) . dy . clone ( ) ) ;
@@ -217,16 +211,9 @@ where
217211 }
218212
219213 // store the final step
220- write_out ( self , & mut ret_y, & mut ret_t) ;
214+ write_out ( self , & mut ret_y, & mut ret_t, final_time , & mut tmp_nout ) ;
221215 let ntimes = ret_t. len ( ) ;
222- let nrows = ret_y[ 0 ] . len ( ) ;
223- let mut ret_y_matrix = self
224- . problem ( )
225- . context ( )
226- . dense_mat_zeros :: < Eqn :: V > ( nrows, ntimes) ;
227- for ( i, y) in ret_y. iter ( ) . enumerate ( ) {
228- ret_y_matrix. column_mut ( i) . copy_from ( y) ;
229- }
216+ ret_y. resize_cols ( ntimes) ;
230217
231218 // add final checkpoint
232219 ts. push ( self . state ( ) . t ) ;
@@ -243,7 +230,7 @@ where
243230 Some ( last_segment) ,
244231 ) ;
245232
246- Ok ( ( checkpointer, ret_y_matrix , ret_t) )
233+ Ok ( ( checkpointer, ret_y , ret_t) )
247234 }
248235
249236 /// Solve the problem and write out the solution at the given timepoints, using checkpointing so that
@@ -265,7 +252,7 @@ where
265252 Eqn :: V : DefaultDenseMatrix ,
266253 Self : Sized ,
267254 {
268- let mut ret = dense_allocate_return ( self , t_eval) ?;
255+ let ( mut ret, mut tmp_nout ) = dense_allocate_return ( self , t_eval) ?;
269256 let max_steps_between_checkpoints = max_steps_between_checkpoints. unwrap_or ( 500 ) ;
270257
271258 // allocate checkpoint info
@@ -296,7 +283,7 @@ where
296283 ydots. clear ( ) ;
297284 }
298285 }
299- dense_write_out ( self , & mut ret, t_eval, i) ?;
286+ dense_write_out ( self , & mut ret, t_eval, i, & mut tmp_nout ) ?;
300287 }
301288 assert_eq ! ( step_reason, OdeSolverStopReason :: TstopReached ) ;
302289
@@ -334,6 +321,7 @@ fn dense_write_out<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
334321 y_out : & mut <Eqn :: V as DefaultDenseMatrix >:: M ,
335322 t_eval : & [ Eqn :: T ] ,
336323 i : usize ,
324+ tmp_nout : & mut Eqn :: V ,
337325) -> Result < ( ) , DiffsolError >
338326where
339327 Eqn :: V : DefaultDenseMatrix ,
@@ -346,7 +334,10 @@ where
346334 } else {
347335 let y = s. interpolate ( t) ?;
348336 match s. problem ( ) . eqn . out ( ) {
349- Some ( out) => y_out. copy_from ( & out. call ( & y, t_eval[ i] ) ) ,
337+ Some ( out) => {
338+ out. call_inplace ( & y, t_eval[ i] , tmp_nout) ;
339+ y_out. copy_from ( tmp_nout)
340+ }
350341 None => y_out. copy_from ( & y) ,
351342 }
352343 }
@@ -357,30 +348,74 @@ where
357348/// This function is used by the `solve` method to write out the solution at a given timepoint.
358349fn write_out < ' a , Eqn : OdeEquations + ' a , S : OdeSolverMethod < ' a , Eqn > > (
359350 s : & S ,
360- ret_y : & mut Vec < Eqn :: V > ,
351+ ret_y : & mut <Eqn :: V as DefaultDenseMatrix > :: M ,
361352 ret_t : & mut Vec < Eqn :: T > ,
362- ) {
353+ final_time : Eqn :: T ,
354+ tmp_nout : & mut Eqn :: V ,
355+ ) where
356+ Eqn :: V : DefaultDenseMatrix ,
357+ {
363358 let t = s. state ( ) . t ;
364359 let y = s. state ( ) . y ;
365360 ret_t. push ( t) ;
361+ let i = ret_t. len ( ) - 1 ;
362+ if i >= ret_y. ncols ( ) {
363+ const GROWTH_FACTOR : f64 = 1.5 ;
364+ let remaining: f64 = ( Eqn :: T :: from ( GROWTH_FACTOR ) * ( final_time - ret_t[ i - 1 ] )
365+ / ( ret_t[ i] - ret_t[ i - 1 ] ) )
366+ . ceil ( )
367+ . into ( ) ;
368+ let n = ret_y. ncols ( ) + ( remaining as usize ) ;
369+ ret_y. resize_cols ( n) ;
370+ }
371+ let mut ret_y_col = ret_y. column_mut ( i) ;
366372 match s. problem ( ) . eqn . out ( ) {
367373 Some ( out) => {
368374 if s. problem ( ) . integrate_out {
369- ret_y . push ( s. state ( ) . g . clone ( ) ) ;
375+ ret_y_col . copy_from ( s. state ( ) . g ) ;
370376 } else {
371- ret_y. push ( out. call ( y, t) ) ;
377+ out. call_inplace ( y, t, tmp_nout) ;
378+ ret_y_col. copy_from ( tmp_nout) ;
372379 }
373380 }
374- None => ret_y . push ( y . clone ( ) ) ,
381+ None => ret_y_col . copy_from ( y ) ,
375382 }
376383}
377384
385+ /// Utility function to allocate the return matrix for the `solve`
386+ /// method
387+ fn allocate_return < ' a , Eqn : OdeEquations + ' a , S : OdeSolverMethod < ' a , Eqn > > (
388+ s : & S ,
389+ ) -> Result < ( <Eqn :: V as DefaultDenseMatrix >:: M , Eqn :: V ) , DiffsolError >
390+ where
391+ Eqn :: V : DefaultDenseMatrix ,
392+ {
393+ let nrows = if s. problem ( ) . eqn . out ( ) . is_some ( ) {
394+ s. problem ( ) . eqn . out ( ) . unwrap ( ) . nout ( )
395+ } else {
396+ s. problem ( ) . eqn . rhs ( ) . nstates ( )
397+ } ;
398+ const INITIAL_NCOLS : usize = 10 ;
399+ let ret = s
400+ . problem ( )
401+ . context ( )
402+ . dense_mat_zeros :: < Eqn :: V > ( nrows, INITIAL_NCOLS ) ;
403+
404+ // check t_eval is increasing and all values are greater than or equal to the current time
405+ let tmp_nout = if let Some ( out) = s. problem ( ) . eqn . out ( ) {
406+ Eqn :: V :: zeros ( out. nout ( ) , s. problem ( ) . context ( ) . clone ( ) )
407+ } else {
408+ Eqn :: V :: zeros ( 0 , s. problem ( ) . context ( ) . clone ( ) )
409+ } ;
410+ Ok ( ( ret, tmp_nout) )
411+ }
412+
378413/// Utility function to allocate the return matrix for the `solve_dense`
379414/// and `solve_dense_sensitivities` methods.
380415fn dense_allocate_return < ' a , Eqn : OdeEquations + ' a , S : OdeSolverMethod < ' a , Eqn > > (
381416 s : & S ,
382417 t_eval : & [ Eqn :: T ] ,
383- ) -> Result < <Eqn :: V as DefaultDenseMatrix >:: M , DiffsolError >
418+ ) -> Result < ( <Eqn :: V as DefaultDenseMatrix >:: M , Eqn :: V ) , DiffsolError >
384419where
385420 Eqn :: V : DefaultDenseMatrix ,
386421{
@@ -399,7 +434,12 @@ where
399434 if t_eval. windows ( 2 ) . any ( |w| w[ 0 ] > w[ 1 ] || w[ 0 ] < t0) {
400435 return Err ( ode_solver_error ! ( InvalidTEval ) ) ;
401436 }
402- Ok ( ret)
437+ let tmp_nout = if let Some ( out) = s. problem ( ) . eqn . out ( ) {
438+ Eqn :: V :: zeros ( out. nout ( ) , s. problem ( ) . context ( ) . clone ( ) )
439+ } else {
440+ Eqn :: V :: zeros ( 0 , s. problem ( ) . context ( ) . clone ( ) )
441+ } ;
442+ Ok ( ( ret, tmp_nout) )
403443}
404444
405445#[ cfg( test) ]
0 commit comments