Skip to content

Commit e575b26

Browse files
perf: avoid Jacobian resets for start/endpoint checkpoints (#237)
* Avoid Jacobian reset for initial and final checkpoints * update snapshots
1 parent df915d6 commit e575b26

File tree

4 files changed

+33
-19
lines changed

4 files changed

+33
-19
lines changed

diffsol/src/ode_solver/bdf.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ where
140140
let op = if let Some(op) = self.op.as_ref() {
141141
let op = op.clone_state(&self.ode_problem.eqn);
142142
nonlinear_solver.set_problem(&op);
143-
nonlinear_solver.reset_jacobian(&op, &self.state.y, self.state.t);
144143
Some(op)
145144
} else {
146145
None
@@ -1138,6 +1137,10 @@ where
11381137
self.state.clone()
11391138
}
11401139

1140+
fn state_clone(&self) -> Self::State {
1141+
self.state.clone()
1142+
}
1143+
11411144
fn step(&mut self) -> Result<OdeSolverStopReason<Eqn::T>, DiffsolError> {
11421145
debug!(
11431146
"Taking BDF step at time {} with step size {} and order {}",
@@ -1663,8 +1666,8 @@ mod test {
16631666
test_adjoint(adjoint_solver, dgdu);
16641667
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
16651668
number_of_calls: 160
1666-
number_of_jac_muls: 6
1667-
number_of_matrix_evals: 3
1669+
number_of_jac_muls: 2
1670+
number_of_matrix_evals: 1
16681671
number_of_jac_adj_muls: 220
16691672
"###);
16701673
}
@@ -1685,8 +1688,8 @@ mod test {
16851688
test_adjoint_sum_squares(adjoint_solver, dgdp, soln, data, times.as_slice());
16861689
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
16871690
number_of_calls: 588
1688-
number_of_jac_muls: 10
1689-
number_of_matrix_evals: 5
1691+
number_of_jac_muls: 6
1692+
number_of_matrix_evals: 3
16901693
number_of_jac_adj_muls: 1054
16911694
"###);
16921695
}
@@ -1737,8 +1740,8 @@ mod test {
17371740
test_adjoint(adjoint_solver, dgdu);
17381741
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
17391742
number_of_calls: 163
1740-
number_of_jac_muls: 18
1741-
number_of_matrix_evals: 6
1743+
number_of_jac_muls: 12
1744+
number_of_matrix_evals: 4
17421745
number_of_jac_adj_muls: 106
17431746
"###);
17441747
}
@@ -1759,8 +1762,8 @@ mod test {
17591762
test_adjoint_sum_squares(adjoint_solver, dgdp, soln, data, times.as_slice());
17601763
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
17611764
number_of_calls: 220
1762-
number_of_jac_muls: 18
1763-
number_of_matrix_evals: 6
1765+
number_of_jac_muls: 12
1766+
number_of_matrix_evals: 4
17641767
number_of_jac_adj_muls: 404
17651768
"###);
17661769
}

diffsol/src/ode_solver/explicit_rk.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ where
159159
self.rk.checkpoint()
160160
}
161161

162+
fn state_clone(&self) -> Self::State {
163+
self.rk.state().clone()
164+
}
165+
162166
fn step(&mut self) -> Result<OdeSolverStopReason<Eqn::T>, DiffsolError> {
163167
let mut h = self.rk.start_step()?;
164168
debug!(

diffsol/src/ode_solver/method.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ where
5656
/// Note that this will force a reinitialisation of the internal Jacobian for the solver, if it has one.
5757
fn checkpoint(&mut self) -> Self::State;
5858

59+
/// Clone the current state of the solver without triggering any internal Jacobian reset.
60+
fn state_clone(&self) -> Self::State;
61+
5962
/// Replace the current state of the solver with a new state.
6063
fn set_state(&mut self, state: Self::State);
6164

@@ -329,7 +332,7 @@ where
329332
// allocate checkpoint info
330333
let mut nsteps = 0;
331334
let t0 = self.state().t;
332-
let mut checkpoints = vec![self.checkpoint()];
335+
let mut checkpoints = vec![self.state_clone()];
333336
let mut ts = vec![t0];
334337
let mut ys = vec![self.state().y.clone()];
335338
let mut ydots = vec![self.state().dy.clone()];
@@ -361,7 +364,7 @@ where
361364
ts.push(self.state().t);
362365
ys.push(self.state().y.clone());
363366
ydots.push(self.state().dy.clone());
364-
checkpoints.push(self.checkpoint());
367+
checkpoints.push(self.state_clone());
365368

366369
// construct checkpointing
367370
let last_segment = HermiteInterpolator::new(ys, ydots, ts);
@@ -420,7 +423,7 @@ where
420423
// allocate checkpoint info
421424
let mut nsteps = 0;
422425
let t0 = self.state().t;
423-
let mut checkpoints = vec![self.checkpoint()];
426+
let mut checkpoints = vec![self.state_clone()];
424427
let mut ts = vec![t0];
425428
let mut ys = vec![self.state().y.clone()];
426429
let mut ydots = vec![self.state().dy.clone()];
@@ -450,7 +453,7 @@ where
450453
assert_eq!(step_reason, OdeSolverStopReason::TstopReached);
451454

452455
// add final checkpoint
453-
checkpoints.push(self.checkpoint());
456+
checkpoints.push(self.state_clone());
454457

455458
// construct the adjoint equations
456459
let last_segment = HermiteInterpolator::new(ys, ydots, ts);

diffsol/src/ode_solver/sdirk.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,10 @@ where
337337
self.rk.state().clone()
338338
}
339339

340+
fn state_clone(&self) -> Self::State {
341+
self.rk.state().clone()
342+
}
343+
340344
fn step(&mut self) -> Result<OdeSolverStopReason<Eqn::T>, DiffsolError> {
341345
debug!(
342346
"Taking SDIRK step at time {}, step size {}",
@@ -695,8 +699,8 @@ mod test {
695699
test_adjoint(adjoint_solver, dgdu);
696700
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
697701
number_of_calls: 419
698-
number_of_jac_muls: 12
699-
number_of_matrix_evals: 6
702+
number_of_jac_muls: 10
703+
number_of_matrix_evals: 5
700704
number_of_jac_adj_muls: 271
701705
"###);
702706
}
@@ -717,8 +721,8 @@ mod test {
717721
test_adjoint_sum_squares(adjoint_solver, dgdp, soln, data, times.as_slice());
718722
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
719723
number_of_calls: 362
720-
number_of_jac_muls: 12
721-
number_of_matrix_evals: 4
724+
number_of_jac_muls: 9
725+
number_of_matrix_evals: 3
722726
number_of_jac_adj_muls: 686
723727
"###);
724728
}
@@ -736,8 +740,8 @@ mod test {
736740
test_adjoint(adjoint_solver, dgdu);
737741
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
738742
number_of_calls: 433
739-
number_of_jac_muls: 30
740-
number_of_matrix_evals: 10
743+
number_of_jac_muls: 27
744+
number_of_matrix_evals: 9
741745
number_of_jac_adj_muls: 169
742746
"###);
743747
}

0 commit comments

Comments
 (0)