Skip to content

Commit 68b066f

Browse files
feat: add config structs for each solver (#181)
* feat: add config structs for each solver * finalise it * add some tests
1 parent 717963c commit 68b066f

File tree

8 files changed

+348
-70
lines changed

8 files changed

+348
-70
lines changed

diffsol/src/lib.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,26 @@ use ode_solver::jacobian_update::JacobianUpdate;
196196
pub use ode_solver::sde::SdeSolverMethod;
197197
pub use ode_solver::state::{StateRef, StateRefMut};
198198
pub use ode_solver::{
199-
adjoint::AdjointOdeSolverMethod, bdf::Bdf, bdf_state::BdfState, builder::OdeBuilder,
200-
checkpointing::Checkpointing, checkpointing::HermiteInterpolator, explicit_rk::ExplicitRk,
201-
method::AugmentedOdeSolverMethod, method::OdeSolverMethod, method::OdeSolverStopReason,
202-
problem::OdeSolverProblem, sdirk::Sdirk, sdirk_state::RkState,
203-
sensitivities::SensitivitiesOdeSolverMethod, state::OdeSolverState, tableau::Tableau,
199+
adjoint::AdjointOdeSolverMethod,
200+
bdf::Bdf,
201+
bdf_state::BdfState,
202+
builder::OdeBuilder,
203+
checkpointing::Checkpointing,
204+
checkpointing::HermiteInterpolator,
205+
config::{
206+
BdfConfig, ExplicitRkConfig, OdeSolverConfig, OdeSolverConfigMut, OdeSolverConfigRef,
207+
SdirkConfig,
208+
},
209+
explicit_rk::ExplicitRk,
210+
method::AugmentedOdeSolverMethod,
211+
method::OdeSolverMethod,
212+
method::OdeSolverStopReason,
213+
problem::OdeSolverProblem,
214+
sdirk::Sdirk,
215+
sdirk_state::RkState,
216+
sensitivities::SensitivitiesOdeSolverMethod,
217+
state::OdeSolverState,
218+
tableau::Tableau,
204219
};
205220
pub use op::constant_op::{ConstantOp, ConstantOpSens, ConstantOpSensAdjoint};
206221
pub use op::linear_op::{LinearOp, LinearOpSens, LinearOpTranspose};

diffsol/src/ode_solver/bdf.rs

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@ use serde::Serialize;
1313
use crate::ode_solver_error;
1414
use crate::{
1515
matrix::MatrixRef, nonlinear_solver::root::RootFinder, op::bdf::BdfCallable, scalar::scale,
16-
AugmentedOdeEquations, BdfState, DenseMatrix, IndexType, JacobianUpdate, MatrixViewMut,
17-
NonLinearOp, NonLinearSolver, OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem,
18-
OdeSolverState, OdeSolverStopReason, Op, Scalar, Vector, VectorRef, VectorView, VectorViewMut,
16+
AugmentedOdeEquations, BdfState, DenseMatrix, JacobianUpdate, MatrixViewMut, NonLinearOp,
17+
NonLinearSolver, OdeEquationsImplicit, OdeSolverMethod, OdeSolverProblem, OdeSolverState,
18+
OdeSolverStopReason, Op, Scalar, Vector, VectorRef, VectorView, VectorViewMut,
1919
};
2020

21+
use super::config::BdfConfig;
2122
use super::jacobian_update::SolverState;
2223
use super::method::AugmentedOdeSolverMethod;
2324

@@ -110,6 +111,7 @@ pub struct Bdf<
110111
root_finder: Option<RootFinder<Eqn::V>>,
111112
is_state_modified: bool,
112113
jacobian_update: JacobianUpdate<Eqn::T>,
114+
config: BdfConfig<Eqn::T>,
113115
}
114116

115117
impl<M, Eqn, Nls, AugmentedEqn> Clone for Bdf<'_, Eqn, Nls, M, AugmentedEqn>
@@ -162,6 +164,7 @@ where
162164
root_finder: self.root_finder.clone(),
163165
is_state_modified: self.is_state_modified,
164166
jacobian_update: self.jacobian_update.clone(),
167+
config: self.config.clone(),
165168
}
166169
}
167170
}
@@ -176,27 +179,20 @@ where
176179
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
177180
Nls: NonLinearSolver<Eqn::M>,
178181
{
179-
const NEWTON_MAXITER: IndexType = 4;
180-
const MIN_FACTOR: f64 = 0.5;
181-
const MAX_FACTOR: f64 = 2.1;
182-
const MAX_THRESHOLD: f64 = 2.0;
183-
const MIN_THRESHOLD: f64 = 0.9;
184-
const MIN_TIMESTEP: f64 = 1e-32;
185-
const MAX_ERROR_TEST_FAILS: usize = 40;
186-
187182
pub fn new(
188183
problem: &'a OdeSolverProblem<Eqn>,
189184
state: BdfState<Eqn::V, M>,
190185
nonlinear_solver: Nls,
191186
) -> Result<Self, DiffsolError> {
192-
Self::_new(problem, state, nonlinear_solver, true)
187+
Self::_new(problem, state, nonlinear_solver, true, BdfConfig::default())
193188
}
194189

195190
fn _new(
196191
problem: &'a OdeSolverProblem<Eqn>,
197192
mut state: BdfState<Eqn::V, M>,
198193
mut nonlinear_solver: Nls,
199194
integrate_main_eqn: bool,
195+
config: BdfConfig<Eqn::T>,
200196
) -> Result<Self, DiffsolError> {
201197
// kappa values for difference orders, taken from Table 1 of [1]
202198
let kappa = [
@@ -226,7 +222,7 @@ where
226222
state.check_consistent_with_problem(problem)?;
227223

228224
let mut convergence = Convergence::new(problem.rtol, &problem.atol);
229-
convergence.set_max_iter(Self::NEWTON_MAXITER);
225+
convergence.set_max_iter(config.maximum_newton_iterations);
230226

231227
let op = if integrate_main_eqn {
232228
// setup linear solver for first step
@@ -297,6 +293,7 @@ where
297293
root_finder,
298294
is_state_modified,
299295
jacobian_update: JacobianUpdate::default(),
296+
config,
300297
})
301298
}
302299

@@ -305,6 +302,22 @@ where
305302
problem: &'a OdeSolverProblem<Eqn>,
306303
augmented_eqn: AugmentedEqn,
307304
nonlinear_solver: Nls,
305+
) -> Result<Self, DiffsolError> {
306+
Self::new_augmented_with_config(
307+
state,
308+
problem,
309+
augmented_eqn,
310+
nonlinear_solver,
311+
BdfConfig::default(),
312+
)
313+
}
314+
315+
pub fn new_augmented_with_config(
316+
state: BdfState<Eqn::V, M>,
317+
problem: &'a OdeSolverProblem<Eqn>,
318+
augmented_eqn: AugmentedEqn,
319+
nonlinear_solver: Nls,
320+
config: BdfConfig<Eqn::T>,
308321
) -> Result<Self, DiffsolError> {
309322
state.check_sens_consistent_with_problem(problem, &augmented_eqn)?;
310323

@@ -313,6 +326,7 @@ where
313326
state,
314327
nonlinear_solver,
315328
augmented_eqn.integrate_main_eqn(),
329+
config,
316330
)?;
317331

318332
ret.state.set_augmented_problem(problem, &augmented_eqn)?;
@@ -456,7 +470,7 @@ where
456470
self.state.h = new_h;
457471

458472
// if step size too small, then fail
459-
if self.state.h.abs() < Eqn::T::from(Self::MIN_TIMESTEP) {
473+
if self.state.h.abs() < self.config.minimum_timestep {
460474
return Err(DiffsolError::from(OdeSolverError::StepSizeTooSmall {
461475
time: self.state.t.into(),
462476
}));
@@ -860,6 +874,15 @@ where
860874
for<'b> &'b Eqn::M: MatrixRef<Eqn::M>,
861875
{
862876
type State = BdfState<Eqn::V, M>;
877+
type Config = BdfConfig<Eqn::T>;
878+
879+
fn config(&self) -> &BdfConfig<Eqn::T> {
880+
&self.config
881+
}
882+
883+
fn config_mut(&mut self) -> &mut BdfConfig<Eqn::T> {
884+
&mut self.config
885+
}
863886

864887
fn order(&self) -> usize {
865888
self.state.order
@@ -1126,8 +1149,8 @@ where
11261149
// calculate optimal step size factor as per eq 2.46 of [2]
11271150
// and reduce step size and try again
11281151
let mut factor = safety * error_norm.pow(Eqn::T::from(-0.5 / (order as f64 + 1.0)));
1129-
if factor < Eqn::T::from(Self::MIN_FACTOR) {
1130-
factor = Eqn::T::from(Self::MIN_FACTOR);
1152+
if factor < self.config.minimum_timestep_shrink {
1153+
factor = self.config.minimum_timestep_shrink;
11311154
}
11321155
let new_h = self._update_step_size(factor)?;
11331156
self._jacobian_updates(new_h * self.alpha[order], SolverState::ErrorTestFail);
@@ -1138,7 +1161,7 @@ where
11381161
// update statistics
11391162
self.statistics.number_of_error_test_failures += 1;
11401163
if self.statistics.number_of_error_test_failures - old_num_error_test_failures
1141-
>= Self::MAX_ERROR_TEST_FAILS
1164+
>= self.config.maximum_error_test_failures
11421165
{
11431166
return Err(DiffsolError::from(
11441167
OdeSolverError::TooManyErrorTestFailures {
@@ -1230,14 +1253,14 @@ where
12301253
};
12311254

12321255
let mut factor = safety * factors[max_index];
1233-
if factor > Eqn::T::from(Self::MAX_FACTOR) {
1234-
factor = Eqn::T::from(Self::MAX_FACTOR);
1256+
if factor > self.config.maximum_timestep_growth {
1257+
factor = self.config.maximum_timestep_growth;
12351258
}
1236-
if factor < Eqn::T::from(Self::MIN_FACTOR) {
1237-
factor = Eqn::T::from(Self::MIN_FACTOR);
1259+
if factor < self.config.minimum_timestep_shrink {
1260+
factor = self.config.minimum_timestep_shrink;
12381261
}
1239-
if factor >= Eqn::T::from(Self::MAX_THRESHOLD)
1240-
|| factor < Eqn::T::from(Self::MIN_THRESHOLD)
1262+
if factor >= self.config.minimum_timestep_growth
1263+
|| factor < self.config.maximum_timestep_shrink
12411264
|| max_index == 0
12421265
|| max_index == 2
12431266
{
@@ -1305,8 +1328,8 @@ mod test {
13051328
},
13061329
ode_solver::tests::{
13071330
setup_test_adjoint, setup_test_adjoint_sum_squares, test_adjoint,
1308-
test_adjoint_sum_squares, test_checkpointing, test_interpolate, test_ode_solver,
1309-
test_problem, test_state_mut, test_state_mut_on_problem,
1331+
test_adjoint_sum_squares, test_checkpointing, test_config, test_interpolate,
1332+
test_ode_solver, test_problem, test_state_mut, test_state_mut_on_problem,
13101333
},
13111334
Context, DenseMatrix, FaerLU, FaerMat, FaerSparseLU, FaerSparseMat, MatrixCommon,
13121335
OdeEquations, OdeSolverMethod, Op, Vector, VectorView,
@@ -1321,6 +1344,11 @@ mod test {
13211344
test_state_mut(test_problem::<M>().bdf::<LS>().unwrap());
13221345
}
13231346

1347+
#[test]
1348+
fn bdf_config() {
1349+
test_config(robertson_ode::<M>(false, 1).0.bdf::<LS>().unwrap());
1350+
}
1351+
13241352
#[test]
13251353
fn bdf_test_interpolate() {
13261354
test_interpolate(test_problem::<M>().bdf::<LS>().unwrap());

diffsol/src/ode_solver/config.rs

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
use crate::Scalar;
2+
3+
pub trait OdeSolverConfig<T> {
4+
fn as_base_ref(&self) -> OdeSolverConfigRef<'_, T>;
5+
fn as_base_mut(&mut self) -> OdeSolverConfigMut<'_, T>;
6+
}
7+
pub struct OdeSolverConfigRef<'a, T> {
8+
pub minimum_timestep: &'a T,
9+
pub maximum_error_test_failures: &'a usize,
10+
pub maximum_timestep_growth: &'a T,
11+
pub minimum_timestep_shrink: &'a T,
12+
}
13+
14+
pub struct OdeSolverConfigMut<'a, T> {
15+
pub minimum_timestep: &'a mut T,
16+
pub maximum_error_test_failures: &'a mut usize,
17+
pub maximum_timestep_growth: &'a mut T,
18+
pub minimum_timestep_shrink: &'a mut T,
19+
}
20+
21+
#[derive(Debug, Clone)]
22+
pub struct BdfConfig<T> {
23+
pub minimum_timestep: T,
24+
pub maximum_error_test_failures: usize,
25+
pub maximum_timestep_growth: T,
26+
pub minimum_timestep_growth: T,
27+
pub maximum_timestep_shrink: T,
28+
pub minimum_timestep_shrink: T,
29+
pub maximum_newton_iterations: usize,
30+
}
31+
32+
impl<T: Scalar> OdeSolverConfig<T> for BdfConfig<T> {
33+
fn as_base_ref(&self) -> OdeSolverConfigRef<'_, T> {
34+
OdeSolverConfigRef {
35+
minimum_timestep: &self.minimum_timestep,
36+
maximum_error_test_failures: &self.maximum_error_test_failures,
37+
maximum_timestep_growth: &self.maximum_timestep_growth,
38+
minimum_timestep_shrink: &self.minimum_timestep_shrink,
39+
}
40+
}
41+
42+
fn as_base_mut(&mut self) -> OdeSolverConfigMut<'_, T> {
43+
OdeSolverConfigMut {
44+
minimum_timestep: &mut self.minimum_timestep,
45+
maximum_error_test_failures: &mut self.maximum_error_test_failures,
46+
maximum_timestep_growth: &mut self.maximum_timestep_growth,
47+
minimum_timestep_shrink: &mut self.minimum_timestep_shrink,
48+
}
49+
}
50+
}
51+
52+
impl<T: Scalar> Default for BdfConfig<T> {
53+
fn default() -> Self {
54+
Self {
55+
minimum_timestep: T::from(1e-32),
56+
maximum_error_test_failures: 40,
57+
maximum_timestep_growth: T::from(2.1),
58+
minimum_timestep_growth: T::from(2.0),
59+
maximum_timestep_shrink: T::from(0.9),
60+
minimum_timestep_shrink: T::from(0.5),
61+
maximum_newton_iterations: 4,
62+
}
63+
}
64+
}
65+
66+
#[derive(Debug, Clone)]
67+
pub struct SdirkConfig<T> {
68+
pub minimum_timestep: T,
69+
pub maximum_error_test_failures: usize,
70+
pub maximum_timestep_growth: T,
71+
pub minimum_timestep_shrink: T,
72+
pub maximum_newton_iterations: usize,
73+
}
74+
75+
impl<T: Scalar> Default for SdirkConfig<T> {
76+
fn default() -> Self {
77+
Self {
78+
minimum_timestep: T::from(1e-13),
79+
maximum_error_test_failures: 40,
80+
maximum_timestep_growth: T::from(10.0),
81+
minimum_timestep_shrink: T::from(0.2),
82+
maximum_newton_iterations: 10,
83+
}
84+
}
85+
}
86+
87+
impl<T: Scalar> OdeSolverConfig<T> for SdirkConfig<T> {
88+
fn as_base_ref(&self) -> OdeSolverConfigRef<'_, T> {
89+
OdeSolverConfigRef {
90+
minimum_timestep: &self.minimum_timestep,
91+
maximum_error_test_failures: &self.maximum_error_test_failures,
92+
maximum_timestep_growth: &self.maximum_timestep_growth,
93+
minimum_timestep_shrink: &self.minimum_timestep_shrink,
94+
}
95+
}
96+
97+
fn as_base_mut(&mut self) -> OdeSolverConfigMut<'_, T> {
98+
OdeSolverConfigMut {
99+
minimum_timestep: &mut self.minimum_timestep,
100+
maximum_error_test_failures: &mut self.maximum_error_test_failures,
101+
maximum_timestep_growth: &mut self.maximum_timestep_growth,
102+
minimum_timestep_shrink: &mut self.minimum_timestep_shrink,
103+
}
104+
}
105+
}
106+
107+
#[derive(Debug, Clone)]
108+
pub struct ExplicitRkConfig<T> {
109+
pub minimum_timestep: T,
110+
pub maximum_error_test_failures: usize,
111+
pub maximum_timestep_growth: T,
112+
pub minimum_timestep_shrink: T,
113+
}
114+
115+
impl<T: Scalar> Default for ExplicitRkConfig<T> {
116+
fn default() -> Self {
117+
Self {
118+
minimum_timestep: T::from(1e-13),
119+
maximum_error_test_failures: 40,
120+
maximum_timestep_growth: T::from(10.0),
121+
minimum_timestep_shrink: T::from(0.2),
122+
}
123+
}
124+
}
125+
126+
impl<T: Scalar> OdeSolverConfig<T> for ExplicitRkConfig<T> {
127+
fn as_base_ref(&self) -> OdeSolverConfigRef<'_, T> {
128+
OdeSolverConfigRef {
129+
minimum_timestep: &self.minimum_timestep,
130+
maximum_error_test_failures: &self.maximum_error_test_failures,
131+
maximum_timestep_growth: &self.maximum_timestep_growth,
132+
minimum_timestep_shrink: &self.minimum_timestep_shrink,
133+
}
134+
}
135+
136+
fn as_base_mut(&mut self) -> OdeSolverConfigMut<'_, T> {
137+
OdeSolverConfigMut {
138+
minimum_timestep: &mut self.minimum_timestep,
139+
maximum_error_test_failures: &mut self.maximum_error_test_failures,
140+
maximum_timestep_growth: &mut self.maximum_timestep_growth,
141+
minimum_timestep_shrink: &mut self.minimum_timestep_shrink,
142+
}
143+
}
144+
}

0 commit comments

Comments
 (0)