@@ -13,11 +13,12 @@ use serde::Serialize;
1313use crate :: ode_solver_error;
1414use crate :: {
1515 matrix:: MatrixRef , nonlinear_solver:: root:: RootFinder , op:: bdf:: BdfCallable , scalar:: scale,
16- AugmentedOdeEquations , BdfState , DenseMatrix , IndexType , JacobianUpdate , MatrixViewMut ,
17- NonLinearOp , NonLinearSolver , OdeEquationsImplicit , OdeSolverMethod , OdeSolverProblem ,
18- OdeSolverState , OdeSolverStopReason , Op , Scalar , Vector , VectorRef , VectorView , VectorViewMut ,
16+ AugmentedOdeEquations , BdfState , DenseMatrix , JacobianUpdate , MatrixViewMut , NonLinearOp ,
17+ NonLinearSolver , OdeEquationsImplicit , OdeSolverMethod , OdeSolverProblem , OdeSolverState ,
18+ OdeSolverStopReason , Op , Scalar , Vector , VectorRef , VectorView , VectorViewMut ,
1919} ;
2020
21+ use super :: config:: BdfConfig ;
2122use super :: jacobian_update:: SolverState ;
2223use super :: method:: AugmentedOdeSolverMethod ;
2324
@@ -110,6 +111,7 @@ pub struct Bdf<
110111 root_finder : Option < RootFinder < Eqn :: V > > ,
111112 is_state_modified : bool ,
112113 jacobian_update : JacobianUpdate < Eqn :: T > ,
114+ config : BdfConfig < Eqn :: T > ,
113115}
114116
115117impl < M , Eqn , Nls , AugmentedEqn > Clone for Bdf < ' _ , Eqn , Nls , M , AugmentedEqn >
@@ -162,6 +164,7 @@ where
162164 root_finder : self . root_finder . clone ( ) ,
163165 is_state_modified : self . is_state_modified ,
164166 jacobian_update : self . jacobian_update . clone ( ) ,
167+ config : self . config . clone ( ) ,
165168 }
166169 }
167170}
@@ -176,27 +179,20 @@ where
176179 for < ' b > & ' b Eqn :: M : MatrixRef < Eqn :: M > ,
177180 Nls : NonLinearSolver < Eqn :: M > ,
178181{
179- const NEWTON_MAXITER : IndexType = 4 ;
180- const MIN_FACTOR : f64 = 0.5 ;
181- const MAX_FACTOR : f64 = 2.1 ;
182- const MAX_THRESHOLD : f64 = 2.0 ;
183- const MIN_THRESHOLD : f64 = 0.9 ;
184- const MIN_TIMESTEP : f64 = 1e-32 ;
185- const MAX_ERROR_TEST_FAILS : usize = 40 ;
186-
187182 pub fn new (
188183 problem : & ' a OdeSolverProblem < Eqn > ,
189184 state : BdfState < Eqn :: V , M > ,
190185 nonlinear_solver : Nls ,
191186 ) -> Result < Self , DiffsolError > {
192- Self :: _new ( problem, state, nonlinear_solver, true )
187+ Self :: _new ( problem, state, nonlinear_solver, true , BdfConfig :: default ( ) )
193188 }
194189
195190 fn _new (
196191 problem : & ' a OdeSolverProblem < Eqn > ,
197192 mut state : BdfState < Eqn :: V , M > ,
198193 mut nonlinear_solver : Nls ,
199194 integrate_main_eqn : bool ,
195+ config : BdfConfig < Eqn :: T > ,
200196 ) -> Result < Self , DiffsolError > {
201197 // kappa values for difference orders, taken from Table 1 of [1]
202198 let kappa = [
@@ -226,7 +222,7 @@ where
226222 state. check_consistent_with_problem ( problem) ?;
227223
228224 let mut convergence = Convergence :: new ( problem. rtol , & problem. atol ) ;
229- convergence. set_max_iter ( Self :: NEWTON_MAXITER ) ;
225+ convergence. set_max_iter ( config . maximum_newton_iterations ) ;
230226
231227 let op = if integrate_main_eqn {
232228 // setup linear solver for first step
@@ -297,6 +293,7 @@ where
297293 root_finder,
298294 is_state_modified,
299295 jacobian_update : JacobianUpdate :: default ( ) ,
296+ config,
300297 } )
301298 }
302299
@@ -305,6 +302,22 @@ where
305302 problem : & ' a OdeSolverProblem < Eqn > ,
306303 augmented_eqn : AugmentedEqn ,
307304 nonlinear_solver : Nls ,
305+ ) -> Result < Self , DiffsolError > {
306+ Self :: new_augmented_with_config (
307+ state,
308+ problem,
309+ augmented_eqn,
310+ nonlinear_solver,
311+ BdfConfig :: default ( ) ,
312+ )
313+ }
314+
315+ pub fn new_augmented_with_config (
316+ state : BdfState < Eqn :: V , M > ,
317+ problem : & ' a OdeSolverProblem < Eqn > ,
318+ augmented_eqn : AugmentedEqn ,
319+ nonlinear_solver : Nls ,
320+ config : BdfConfig < Eqn :: T > ,
308321 ) -> Result < Self , DiffsolError > {
309322 state. check_sens_consistent_with_problem ( problem, & augmented_eqn) ?;
310323
@@ -313,6 +326,7 @@ where
313326 state,
314327 nonlinear_solver,
315328 augmented_eqn. integrate_main_eqn ( ) ,
329+ config,
316330 ) ?;
317331
318332 ret. state . set_augmented_problem ( problem, & augmented_eqn) ?;
@@ -456,7 +470,7 @@ where
456470 self . state . h = new_h;
457471
458472 // if step size too small, then fail
459- if self . state . h . abs ( ) < Eqn :: T :: from ( Self :: MIN_TIMESTEP ) {
473+ if self . state . h . abs ( ) < self . config . minimum_timestep {
460474 return Err ( DiffsolError :: from ( OdeSolverError :: StepSizeTooSmall {
461475 time : self . state . t . into ( ) ,
462476 } ) ) ;
@@ -860,6 +874,15 @@ where
860874 for < ' b > & ' b Eqn :: M : MatrixRef < Eqn :: M > ,
861875{
862876 type State = BdfState < Eqn :: V , M > ;
877+ type Config = BdfConfig < Eqn :: T > ;
878+
879+ fn config ( & self ) -> & BdfConfig < Eqn :: T > {
880+ & self . config
881+ }
882+
883+ fn config_mut ( & mut self ) -> & mut BdfConfig < Eqn :: T > {
884+ & mut self . config
885+ }
863886
864887 fn order ( & self ) -> usize {
865888 self . state . order
@@ -1126,8 +1149,8 @@ where
11261149 // calculate optimal step size factor as per eq 2.46 of [2]
11271150 // and reduce step size and try again
11281151 let mut factor = safety * error_norm. pow ( Eqn :: T :: from ( -0.5 / ( order as f64 + 1.0 ) ) ) ;
1129- if factor < Eqn :: T :: from ( Self :: MIN_FACTOR ) {
1130- factor = Eqn :: T :: from ( Self :: MIN_FACTOR ) ;
1152+ if factor < self . config . minimum_timestep_shrink {
1153+ factor = self . config . minimum_timestep_shrink ;
11311154 }
11321155 let new_h = self . _update_step_size ( factor) ?;
11331156 self . _jacobian_updates ( new_h * self . alpha [ order] , SolverState :: ErrorTestFail ) ;
@@ -1138,7 +1161,7 @@ where
11381161 // update statistics
11391162 self . statistics . number_of_error_test_failures += 1 ;
11401163 if self . statistics . number_of_error_test_failures - old_num_error_test_failures
1141- >= Self :: MAX_ERROR_TEST_FAILS
1164+ >= self . config . maximum_error_test_failures
11421165 {
11431166 return Err ( DiffsolError :: from (
11441167 OdeSolverError :: TooManyErrorTestFailures {
@@ -1230,14 +1253,14 @@ where
12301253 } ;
12311254
12321255 let mut factor = safety * factors[ max_index] ;
1233- if factor > Eqn :: T :: from ( Self :: MAX_FACTOR ) {
1234- factor = Eqn :: T :: from ( Self :: MAX_FACTOR ) ;
1256+ if factor > self . config . maximum_timestep_growth {
1257+ factor = self . config . maximum_timestep_growth ;
12351258 }
1236- if factor < Eqn :: T :: from ( Self :: MIN_FACTOR ) {
1237- factor = Eqn :: T :: from ( Self :: MIN_FACTOR ) ;
1259+ if factor < self . config . minimum_timestep_shrink {
1260+ factor = self . config . minimum_timestep_shrink ;
12381261 }
1239- if factor >= Eqn :: T :: from ( Self :: MAX_THRESHOLD )
1240- || factor < Eqn :: T :: from ( Self :: MIN_THRESHOLD )
1262+ if factor >= self . config . minimum_timestep_growth
1263+ || factor < self . config . maximum_timestep_shrink
12411264 || max_index == 0
12421265 || max_index == 2
12431266 {
@@ -1305,8 +1328,8 @@ mod test {
13051328 } ,
13061329 ode_solver:: tests:: {
13071330 setup_test_adjoint, setup_test_adjoint_sum_squares, test_adjoint,
1308- test_adjoint_sum_squares, test_checkpointing, test_interpolate , test_ode_solver ,
1309- test_problem, test_state_mut, test_state_mut_on_problem,
1331+ test_adjoint_sum_squares, test_checkpointing, test_config , test_interpolate ,
1332+ test_ode_solver , test_problem, test_state_mut, test_state_mut_on_problem,
13101333 } ,
13111334 Context , DenseMatrix , FaerLU , FaerMat , FaerSparseLU , FaerSparseMat , MatrixCommon ,
13121335 OdeEquations , OdeSolverMethod , Op , Vector , VectorView ,
@@ -1321,6 +1344,11 @@ mod test {
13211344 test_state_mut ( test_problem :: < M > ( ) . bdf :: < LS > ( ) . unwrap ( ) ) ;
13221345 }
13231346
1347+ #[ test]
1348+ fn bdf_config ( ) {
1349+ test_config ( robertson_ode :: < M > ( false , 1 ) . 0 . bdf :: < LS > ( ) . unwrap ( ) ) ;
1350+ }
1351+
13241352 #[ test]
13251353 fn bdf_test_interpolate ( ) {
13261354 test_interpolate ( test_problem :: < M > ( ) . bdf :: < LS > ( ) . unwrap ( ) ) ;
0 commit comments