Skip to content

Commit a18916c

Browse files
committed
fix: solve and solve_dense respect roots in problem
1 parent 413341b commit a18916c

File tree

1 file changed

+109
-8
lines changed

1 file changed

+109
-8
lines changed

diffsol/src/ode_solver/method.rs

Lines changed: 109 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ where
161161
///
162162
/// # Post-condition
163163
/// After the solver finishes, the internal state of the solver is at time `final_time`.
164+
/// If a root is found, the solver stops early. The internal state is moved to the root time,
165+
/// and the root time/value are returned as the last entry.
164166
#[allow(clippy::type_complexity)]
165167
fn solve(
166168
&mut self,
@@ -176,12 +178,40 @@ where
176178
// do the main loop
177179
write_out(self, &mut ret_y, &mut ret_t, &mut tmp_nout);
178180
self.set_stop_time(final_time)?;
179-
while self.step()? != OdeSolverStopReason::TstopReached {
180-
write_out(self, &mut ret_y, &mut ret_t, &mut tmp_nout);
181+
loop {
182+
match self.step()? {
183+
OdeSolverStopReason::InternalTimestep => {
184+
write_out(self, &mut ret_y, &mut ret_t, &mut tmp_nout);
185+
}
186+
OdeSolverStopReason::TstopReached => {
187+
write_out(self, &mut ret_y, &mut ret_t, &mut tmp_nout);
188+
break;
189+
}
190+
OdeSolverStopReason::RootFound(t_root) => {
191+
let nstates = self.problem().eqn.rhs().nstates();
192+
let mut y_root =
193+
Eqn::V::zeros(nstates, self.problem().context().clone());
194+
self.interpolate_inplace(t_root, &mut y_root)?;
195+
let integrate_out = self.problem().integrate_out;
196+
let mut g_root = None;
197+
if integrate_out {
198+
let mut g = self.state().g.clone();
199+
self.interpolate_out_inplace(t_root, &mut g)?;
200+
g_root = Some(g);
201+
}
202+
{
203+
let state = self.state_mut();
204+
state.y.copy_from(&y_root);
205+
*state.t = t_root;
206+
if let Some(g) = g_root.as_ref() {
207+
state.g.copy_from(g);
208+
}
209+
}
210+
write_out(self, &mut ret_y, &mut ret_t, &mut tmp_nout);
211+
break;
212+
}
213+
}
181214
}
182-
183-
// store the final step
184-
write_out(self, &mut ret_y, &mut ret_t, &mut tmp_nout);
185215
let ntimes = ret_t.len();
186216
ret_y.resize_cols(ntimes);
187217
Ok((ret_y, ret_t))
@@ -202,6 +232,8 @@ where
202232
///
203233
/// # Post-condition
204234
/// After the solver finishes, the internal state of the solver is at time `t_eval[t_eval.len()-1]`.
235+
/// If a root is found, the solver stops early. The internal state is moved to the root time,
236+
/// and the last column corresponds to the root time (which may not be in `t_eval`).
205237
fn solve_dense(
206238
&mut self,
207239
t_eval: &[Eqn::T],
@@ -214,14 +246,49 @@ where
214246

215247
// do loop
216248
self.set_stop_time(t_eval[t_eval.len() - 1])?;
217-
let mut step_reason = OdeSolverStopReason::InternalTimestep;
218249
for (i, t) in t_eval.iter().enumerate() {
219250
while self.state().t < *t {
220-
step_reason = self.step()?;
251+
match self.step()? {
252+
OdeSolverStopReason::InternalTimestep => {}
253+
OdeSolverStopReason::TstopReached => break,
254+
OdeSolverStopReason::RootFound(t_root) => {
255+
self.interpolate_inplace(t_root, &mut tmp_nstates)?;
256+
let integrate_out = self.problem().integrate_out;
257+
let mut g_root = None;
258+
if integrate_out {
259+
let mut g = self.state().g.clone();
260+
self.interpolate_out_inplace(t_root, &mut g)?;
261+
g_root = Some(g);
262+
}
263+
{
264+
let state = self.state_mut();
265+
state.y.copy_from(&tmp_nstates);
266+
*state.t = t_root;
267+
if let Some(g) = g_root.as_ref() {
268+
state.g.copy_from(g);
269+
}
270+
}
271+
{
272+
let mut y_out = ret.column_mut(i);
273+
if integrate_out {
274+
y_out.copy_from(g_root.as_ref().unwrap());
275+
} else {
276+
match self.problem().eqn.out() {
277+
Some(out) => {
278+
out.call_inplace(&tmp_nstates, t_root, &mut tmp_nout);
279+
y_out.copy_from(&tmp_nout);
280+
}
281+
None => y_out.copy_from(&tmp_nstates),
282+
}
283+
}
284+
}
285+
ret.resize_cols(i + 1);
286+
return Ok(ret);
287+
}
288+
}
221289
}
222290
dense_write_out(self, &mut ret, t_eval, i, &mut tmp_nout, &mut tmp_nstates)?;
223291
}
224-
assert_eq!(step_reason, OdeSolverStopReason::TstopReached);
225292
Ok(ret)
226293
}
227294

@@ -543,10 +610,12 @@ where
543610
mod test {
544611
use crate::{
545612
error::{DiffsolError, OdeSolverError},
613+
matrix::MatrixCommon,
546614
matrix::dense_nalgebra_serial::NalgebraMat,
547615
ode_equations::test_models::exponential_decay::{
548616
exponential_decay_problem, exponential_decay_problem_adjoint,
549617
exponential_decay_problem_sens, exponential_decay_problem_sens_with_out,
618+
exponential_decay_problem_with_root,
550619
},
551620
scale, AdjointOdeSolverMethod, DenseMatrix, NalgebraLU, NalgebraVec, OdeEquations,
552621
OdeSolverMethod, Op, SensitivitiesOdeSolverMethod, Vector, VectorView,
@@ -569,6 +638,22 @@ mod test {
569638
}
570639
}
571640

641+
#[test]
642+
fn test_solve_stops_on_root() {
643+
let (problem, _soln) = exponential_decay_problem_with_root::<NalgebraMat<f64>>(false);
644+
let mut s = problem.bdf::<NalgebraLU<f64>>().unwrap();
645+
646+
let (y, t) = s.solve(10.0).unwrap();
647+
let t_root = -0.6_f64.ln() / 0.1;
648+
let t_last = *t.last().unwrap();
649+
assert!((t_last - t_root).abs() < 1e-3);
650+
assert!((s.state().t - t_root).abs() < 1e-3);
651+
652+
let y_last = y.column(y.ncols() - 1).into_owned();
653+
let expected = NalgebraVec::from_vec(vec![0.6, 0.6], *problem.context());
654+
y_last.assert_eq_norm(&expected, &problem.atol, problem.rtol, 15.0);
655+
}
656+
572657
#[test]
573658
fn test_solve_integrate_out() {
574659
let (problem, _soln) = exponential_decay_problem_adjoint::<NalgebraMat<f64>>(true);
@@ -604,6 +689,22 @@ mod test {
604689
}
605690
}
606691

692+
#[test]
693+
fn test_dense_solve_stops_on_root() {
694+
let (problem, _soln) = exponential_decay_problem_with_root::<NalgebraMat<f64>>(false);
695+
let mut s = problem.bdf::<NalgebraLU<f64>>().unwrap();
696+
697+
let t_eval = (0..=10).map(|i| i as f64).collect::<Vec<_>>();
698+
let y = s.solve_dense(t_eval.as_slice()).unwrap();
699+
let t_root = -0.6_f64.ln() / 0.1;
700+
assert!((s.state().t - t_root).abs() < 1e-3);
701+
assert!(y.ncols() < t_eval.len());
702+
703+
let y_last = y.column(y.ncols() - 1).into_owned();
704+
let expected = NalgebraVec::from_vec(vec![0.6, 0.6], *problem.context());
705+
y_last.assert_eq_norm(&expected, &problem.atol, problem.rtol, 15.0);
706+
}
707+
607708
#[test]
608709
fn test_dense_solve_integrate_out() {
609710
let (problem, soln) = exponential_decay_problem_adjoint::<NalgebraMat<f64>>(true);

0 commit comments

Comments
 (0)