@@ -9,79 +9,6 @@ use crate::{
99 OdeSolverState , Op , StateRef , StateRefMut , Vector , VectorViewMut ,
1010} ;
1111
12- /// Utility function to write out the solution at a given timepoint
13- /// This function is used by the `solve_dense` method to write out the solution at a given timepoint.
14- fn dense_write_out < ' a , Eqn : OdeEquations + ' a , S : OdeSolverMethod < ' a , Eqn > > (
15- s : & S ,
16- y_out : & mut <Eqn :: V as DefaultDenseMatrix >:: M ,
17- t_eval : & [ Eqn :: T ] ,
18- i : usize ,
19- ) -> Result < ( ) , DiffsolError >
20- where
21- Eqn :: V : DefaultDenseMatrix ,
22- {
23- let mut y_out = y_out. column_mut ( i) ;
24- let t = t_eval[ i] ;
25- if s. problem ( ) . integrate_out {
26- let g = s. interpolate_out ( t) ?;
27- y_out. copy_from ( & g) ;
28- } else {
29- let y = s. interpolate ( t) ?;
30- match s. problem ( ) . eqn . out ( ) {
31- Some ( out) => y_out. copy_from ( & out. call ( & y, t_eval[ i] ) ) ,
32- None => y_out. copy_from ( & y) ,
33- }
34- }
35- Ok ( ( ) )
36- }
37-
38- /// utility function to write out the solution at a given timepoint
39- /// This function is used by the `solve` method to write out the solution at a given timepoint.
40- fn write_out < ' a , Eqn : OdeEquations + ' a , S : OdeSolverMethod < ' a , Eqn > > (
41- s : & S ,
42- ret_y : & mut Vec < Eqn :: V > ,
43- ret_t : & mut Vec < Eqn :: T > ,
44- ) {
45- let t = s. state ( ) . t ;
46- let y = s. state ( ) . y ;
47- ret_t. push ( t) ;
48- match s. problem ( ) . eqn . out ( ) {
49- Some ( out) => {
50- if s. problem ( ) . integrate_out {
51- ret_y. push ( s. state ( ) . g . clone ( ) ) ;
52- } else {
53- ret_y. push ( out. call ( y, t) ) ;
54- }
55- }
56- None => ret_y. push ( y. clone ( ) ) ,
57- }
58- }
59-
60- fn dense_allocate_return < ' a , Eqn : OdeEquations + ' a , S : OdeSolverMethod < ' a , Eqn > > (
61- s : & S ,
62- t_eval : & [ Eqn :: T ] ,
63- ) -> Result < <Eqn :: V as DefaultDenseMatrix >:: M , DiffsolError >
64- where
65- Eqn :: V : DefaultDenseMatrix ,
66- {
67- let nrows = if s. problem ( ) . eqn . out ( ) . is_some ( ) {
68- s. problem ( ) . eqn . out ( ) . unwrap ( ) . nout ( )
69- } else {
70- s. problem ( ) . eqn . rhs ( ) . nstates ( )
71- } ;
72- let ret = s
73- . problem ( )
74- . context ( )
75- . dense_mat_zeros :: < Eqn :: V > ( nrows, t_eval. len ( ) ) ;
76-
77- // check t_eval is increasing and all values are greater than or equal to the current time
78- let t0 = s. state ( ) . t ;
79- if t_eval. windows ( 2 ) . any ( |w| w[ 0 ] > w[ 1 ] || w[ 0 ] < t0) {
80- return Err ( ode_solver_error ! ( InvalidTEval ) ) ;
81- }
82- Ok ( ret)
83- }
84-
8512#[ derive( Debug , PartialEq ) ]
8613pub enum OdeSolverStopReason < T : Scalar > {
8714 InternalTimestep ,
@@ -400,9 +327,85 @@ where
400327 fn augmented_eqn ( & self ) -> Option < & AugmentedEqn > ;
401328}
402329
330+ /// Utility function to write out the solution at a given timepoint
331+ /// This function is used by the `solve_dense` method to write out the solution at a given timepoint.
332+ fn dense_write_out < ' a , Eqn : OdeEquations + ' a , S : OdeSolverMethod < ' a , Eqn > > (
333+ s : & S ,
334+ y_out : & mut <Eqn :: V as DefaultDenseMatrix >:: M ,
335+ t_eval : & [ Eqn :: T ] ,
336+ i : usize ,
337+ ) -> Result < ( ) , DiffsolError >
338+ where
339+ Eqn :: V : DefaultDenseMatrix ,
340+ {
341+ let mut y_out = y_out. column_mut ( i) ;
342+ let t = t_eval[ i] ;
343+ if s. problem ( ) . integrate_out {
344+ let g = s. interpolate_out ( t) ?;
345+ y_out. copy_from ( & g) ;
346+ } else {
347+ let y = s. interpolate ( t) ?;
348+ match s. problem ( ) . eqn . out ( ) {
349+ Some ( out) => y_out. copy_from ( & out. call ( & y, t_eval[ i] ) ) ,
350+ None => y_out. copy_from ( & y) ,
351+ }
352+ }
353+ Ok ( ( ) )
354+ }
355+
356+ /// utility function to write out the solution at a given timepoint
357+ /// This function is used by the `solve` method to write out the solution at a given timepoint.
358+ fn write_out < ' a , Eqn : OdeEquations + ' a , S : OdeSolverMethod < ' a , Eqn > > (
359+ s : & S ,
360+ ret_y : & mut Vec < Eqn :: V > ,
361+ ret_t : & mut Vec < Eqn :: T > ,
362+ ) {
363+ let t = s. state ( ) . t ;
364+ let y = s. state ( ) . y ;
365+ ret_t. push ( t) ;
366+ match s. problem ( ) . eqn . out ( ) {
367+ Some ( out) => {
368+ if s. problem ( ) . integrate_out {
369+ ret_y. push ( s. state ( ) . g . clone ( ) ) ;
370+ } else {
371+ ret_y. push ( out. call ( y, t) ) ;
372+ }
373+ }
374+ None => ret_y. push ( y. clone ( ) ) ,
375+ }
376+ }
377+
378+ /// Utility function to allocate the return matrix for the `solve_dense`
379+ /// and `solve_dense_sensitivities` methods.
380+ fn dense_allocate_return < ' a , Eqn : OdeEquations + ' a , S : OdeSolverMethod < ' a , Eqn > > (
381+ s : & S ,
382+ t_eval : & [ Eqn :: T ] ,
383+ ) -> Result < <Eqn :: V as DefaultDenseMatrix >:: M , DiffsolError >
384+ where
385+ Eqn :: V : DefaultDenseMatrix ,
386+ {
387+ let nrows = if s. problem ( ) . eqn . out ( ) . is_some ( ) {
388+ s. problem ( ) . eqn . out ( ) . unwrap ( ) . nout ( )
389+ } else {
390+ s. problem ( ) . eqn . rhs ( ) . nstates ( )
391+ } ;
392+ let ret = s
393+ . problem ( )
394+ . context ( )
395+ . dense_mat_zeros :: < Eqn :: V > ( nrows, t_eval. len ( ) ) ;
396+
397+ // check t_eval is increasing and all values are greater than or equal to the current time
398+ let t0 = s. state ( ) . t ;
399+ if t_eval. windows ( 2 ) . any ( |w| w[ 0 ] > w[ 1 ] || w[ 0 ] < t0) {
400+ return Err ( ode_solver_error ! ( InvalidTEval ) ) ;
401+ }
402+ Ok ( ret)
403+ }
404+
403405#[ cfg( test) ]
404406mod test {
405407 use crate :: {
408+ error:: { DiffsolError , OdeSolverError } ,
406409 matrix:: dense_nalgebra_serial:: NalgebraMat ,
407410 ode_equations:: test_models:: exponential_decay:: {
408411 exponential_decay_problem, exponential_decay_problem_adjoint,
@@ -477,6 +480,18 @@ mod test {
477480 }
478481 }
479482
483+ #[ test]
484+ fn test_t_eval_errors ( ) {
485+ let ( problem, _soln) = exponential_decay_problem :: < NalgebraMat < f64 > > ( false ) ;
486+ let mut s = problem. bdf :: < NalgebraLU < f64 > > ( ) . unwrap ( ) ;
487+ let t_eval = vec ! [ 0.0 , 1.0 , 0.5 , 2.0 ] ;
488+ let err = s. solve_dense ( t_eval. as_slice ( ) ) . unwrap_err ( ) ;
489+ assert ! ( matches!(
490+ err,
491+ DiffsolError :: OdeSolverError ( OdeSolverError :: InvalidTEval )
492+ ) ) ;
493+ }
494+
480495 #[ test]
481496 fn test_dense_solve_sensitivities ( ) {
482497 let ( problem, soln) = exponential_decay_problem_sens :: < NalgebraMat < f64 > > ( false ) ;
0 commit comments