11use crate :: {
22 error:: DiffsolError , error:: OdeSolverError , ode_solver_error, AugmentedOdeSolverMethod ,
3- Context , DefaultDenseMatrix , DefaultSolver , DenseMatrix , OdeEquationsImplicitSens ,
4- OdeSolverStopReason , Op , SensEquations , VectorViewMut ,
3+ Context , DefaultDenseMatrix , DefaultSolver , DenseMatrix , NonLinearOp , NonLinearOpJacobian ,
4+ NonLinearOpSens , OdeEquationsImplicitSens , OdeSolverStopReason , Op , SensEquations , Vector ,
5+ VectorViewMut ,
56} ;
7+ use num_traits:: { One , Zero } ;
8+ use std:: ops:: AddAssign ;
69
710pub trait SensitivitiesOdeSolverMethod < ' a , Eqn > :
811 AugmentedOdeSolverMethod < ' a , Eqn , SensEquations < ' a , Eqn > >
3639 "Cannot integrate out when solving for sensitivities"
3740 ) ) ;
3841 }
39- let nrows = self . problem ( ) . eqn . rhs ( ) . nstates ( ) ;
42+ let ( mut tmp_nout, mut tmp_nparms, nrows) = if let Some ( out) = self . problem ( ) . eqn . out ( ) {
43+ (
44+ Some ( Eqn :: V :: zeros ( out. nout ( ) , self . problem ( ) . context ( ) . clone ( ) ) ) ,
45+ Some ( Eqn :: V :: zeros (
46+ out. nparams ( ) ,
47+ self . problem ( ) . context ( ) . clone ( ) ,
48+ ) ) ,
49+ out. nout ( ) ,
50+ )
51+ } else {
52+ ( None , None , self . problem ( ) . eqn . rhs ( ) . nout ( ) )
53+ } ;
54+
4055 let mut ret = self
4156 . problem ( )
4257 . context ( )
@@ -62,10 +77,27 @@ where
6277 step_reason = self . step ( ) ?;
6378 }
6479 let y = self . interpolate ( * t) ?;
65- ret. column_mut ( i) . copy_from ( & y) ;
66- let s = self . interpolate_sens ( * t) ?;
67- for ( j, s_j) in s. iter ( ) . enumerate ( ) {
68- ret_sens[ j] . column_mut ( i) . copy_from ( s_j) ;
80+ let mut s = self . interpolate_sens ( * t) ?;
81+ if let Some ( out) = self . problem ( ) . eqn . out ( ) {
82+ let tmp_nout = tmp_nout. as_mut ( ) . unwrap ( ) ;
83+ let tmp_nparams = tmp_nparms. as_mut ( ) . unwrap ( ) ;
84+ out. call_inplace ( & y, * t, tmp_nout) ;
85+ ret. column_mut ( i) . copy_from ( tmp_nout) ;
86+ for ( j, s_j) in s. iter_mut ( ) . enumerate ( ) {
87+ // compute J * s_j + dF/dp * e_j where e_j is the jth basis vector
88+ tmp_nparams. set_index ( j, Eqn :: T :: one ( ) ) ;
89+ out. jac_mul_inplace ( & y, * t, s_j, tmp_nout) ;
90+ s_j. copy_from ( tmp_nout) ;
91+ out. sens_mul_inplace ( & y, * t, tmp_nparams, tmp_nout) ;
92+ s_j. add_assign ( & * tmp_nout) ;
93+ ret_sens[ j] . column_mut ( i) . copy_from ( s_j) ;
94+ tmp_nparams. set_index ( j, Eqn :: T :: zero ( ) ) ;
95+ }
96+ } else {
97+ ret. column_mut ( i) . copy_from ( & y) ;
98+ for ( j, s_j) in s. iter ( ) . enumerate ( ) {
99+ ret_sens[ j] . column_mut ( i) . copy_from ( s_j) ;
100+ }
69101 }
70102 }
71103
@@ -74,11 +106,33 @@ where
74106 step_reason = self . step ( ) ?;
75107 }
76108 let y = self . state ( ) . y ;
77- ret. column_mut ( t_eval. len ( ) - 1 ) . copy_from ( y) ;
78109 let s = self . state ( ) . s ;
79- for ( j, s_j) in s. iter ( ) . enumerate ( ) {
80- ret_sens[ j] . column_mut ( t_eval. len ( ) - 1 ) . copy_from ( s_j) ;
110+ let mut s_tmp = tmp_nout. clone ( ) ;
111+ let i = t_eval. len ( ) - 1 ;
112+ let t = t_eval. last ( ) . unwrap ( ) ;
113+ if let Some ( out) = self . problem ( ) . eqn . out ( ) {
114+ let tmp_nout = tmp_nout. as_mut ( ) . unwrap ( ) ;
115+ let tmp_nparams = tmp_nparms. as_mut ( ) . unwrap ( ) ;
116+ let s_tmp = s_tmp. as_mut ( ) . unwrap ( ) ;
117+ out. call_inplace ( y, * t, tmp_nout) ;
118+ ret. column_mut ( i) . copy_from ( tmp_nout) ;
119+ for ( j, s_j) in s. iter ( ) . enumerate ( ) {
120+ // compute J * s_j + dF/dp * e_j where e_j is the jth basis vector
121+ tmp_nparams. set_index ( j, Eqn :: T :: one ( ) ) ;
122+ out. jac_mul_inplace ( y, * t, s_j, tmp_nout) ;
123+ s_tmp. copy_from ( tmp_nout) ;
124+ out. sens_mul_inplace ( y, * t, tmp_nparams, tmp_nout) ;
125+ s_tmp. add_assign ( & * tmp_nout) ;
126+ ret_sens[ j] . column_mut ( i) . copy_from ( s_tmp) ;
127+ tmp_nparams. set_index ( j, Eqn :: T :: zero ( ) ) ;
128+ }
129+ } else {
130+ ret. column_mut ( i) . copy_from ( y) ;
131+ for ( j, s_j) in s. iter ( ) . enumerate ( ) {
132+ ret_sens[ j] . column_mut ( i) . copy_from ( s_j) ;
133+ }
81134 }
135+
82136 Ok ( ( ret, ret_sens) )
83137 }
84138}
0 commit comments