Skip to content

Commit 73bd6b7

Browse files
refactor: tidy up output in sens_dense_sensitivities (#186)
* refactor: tidy up output in sens_dense_sensitivities * add test for invalid t_eval
1 parent 72535f8 commit 73bd6b7

File tree

2 files changed

+94
-81
lines changed

2 files changed

+94
-81
lines changed

diffsol/src/ode_solver/method.rs

Lines changed: 88 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -9,79 +9,6 @@ use crate::{
99
OdeSolverState, Op, StateRef, StateRefMut, Vector, VectorViewMut,
1010
};
1111

12-
/// Utility function to write out the solution at a given timepoint
13-
/// This function is used by the `solve_dense` method to write out the solution at a given timepoint.
14-
fn dense_write_out<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
15-
s: &S,
16-
y_out: &mut <Eqn::V as DefaultDenseMatrix>::M,
17-
t_eval: &[Eqn::T],
18-
i: usize,
19-
) -> Result<(), DiffsolError>
20-
where
21-
Eqn::V: DefaultDenseMatrix,
22-
{
23-
let mut y_out = y_out.column_mut(i);
24-
let t = t_eval[i];
25-
if s.problem().integrate_out {
26-
let g = s.interpolate_out(t)?;
27-
y_out.copy_from(&g);
28-
} else {
29-
let y = s.interpolate(t)?;
30-
match s.problem().eqn.out() {
31-
Some(out) => y_out.copy_from(&out.call(&y, t_eval[i])),
32-
None => y_out.copy_from(&y),
33-
}
34-
}
35-
Ok(())
36-
}
37-
38-
/// utility function to write out the solution at a given timepoint
39-
/// This function is used by the `solve` method to write out the solution at a given timepoint.
40-
fn write_out<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
41-
s: &S,
42-
ret_y: &mut Vec<Eqn::V>,
43-
ret_t: &mut Vec<Eqn::T>,
44-
) {
45-
let t = s.state().t;
46-
let y = s.state().y;
47-
ret_t.push(t);
48-
match s.problem().eqn.out() {
49-
Some(out) => {
50-
if s.problem().integrate_out {
51-
ret_y.push(s.state().g.clone());
52-
} else {
53-
ret_y.push(out.call(y, t));
54-
}
55-
}
56-
None => ret_y.push(y.clone()),
57-
}
58-
}
59-
60-
fn dense_allocate_return<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
61-
s: &S,
62-
t_eval: &[Eqn::T],
63-
) -> Result<<Eqn::V as DefaultDenseMatrix>::M, DiffsolError>
64-
where
65-
Eqn::V: DefaultDenseMatrix,
66-
{
67-
let nrows = if s.problem().eqn.out().is_some() {
68-
s.problem().eqn.out().unwrap().nout()
69-
} else {
70-
s.problem().eqn.rhs().nstates()
71-
};
72-
let ret = s
73-
.problem()
74-
.context()
75-
.dense_mat_zeros::<Eqn::V>(nrows, t_eval.len());
76-
77-
// check t_eval is increasing and all values are greater than or equal to the current time
78-
let t0 = s.state().t;
79-
if t_eval.windows(2).any(|w| w[0] > w[1] || w[0] < t0) {
80-
return Err(ode_solver_error!(InvalidTEval));
81-
}
82-
Ok(ret)
83-
}
84-
8512
#[derive(Debug, PartialEq)]
8613
pub enum OdeSolverStopReason<T: Scalar> {
8714
InternalTimestep,
@@ -400,9 +327,85 @@ where
400327
fn augmented_eqn(&self) -> Option<&AugmentedEqn>;
401328
}
402329

330+
/// Utility function to write out the solution at a given timepoint
331+
/// This function is used by the `solve_dense` method to write out the solution at a given timepoint.
332+
fn dense_write_out<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
333+
s: &S,
334+
y_out: &mut <Eqn::V as DefaultDenseMatrix>::M,
335+
t_eval: &[Eqn::T],
336+
i: usize,
337+
) -> Result<(), DiffsolError>
338+
where
339+
Eqn::V: DefaultDenseMatrix,
340+
{
341+
let mut y_out = y_out.column_mut(i);
342+
let t = t_eval[i];
343+
if s.problem().integrate_out {
344+
let g = s.interpolate_out(t)?;
345+
y_out.copy_from(&g);
346+
} else {
347+
let y = s.interpolate(t)?;
348+
match s.problem().eqn.out() {
349+
Some(out) => y_out.copy_from(&out.call(&y, t_eval[i])),
350+
None => y_out.copy_from(&y),
351+
}
352+
}
353+
Ok(())
354+
}
355+
356+
/// utility function to write out the solution at a given timepoint
357+
/// This function is used by the `solve` method to write out the solution at a given timepoint.
358+
fn write_out<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
359+
s: &S,
360+
ret_y: &mut Vec<Eqn::V>,
361+
ret_t: &mut Vec<Eqn::T>,
362+
) {
363+
let t = s.state().t;
364+
let y = s.state().y;
365+
ret_t.push(t);
366+
match s.problem().eqn.out() {
367+
Some(out) => {
368+
if s.problem().integrate_out {
369+
ret_y.push(s.state().g.clone());
370+
} else {
371+
ret_y.push(out.call(y, t));
372+
}
373+
}
374+
None => ret_y.push(y.clone()),
375+
}
376+
}
377+
378+
/// Utility function to allocate the return matrix for the `solve_dense`
379+
/// and `solve_dense_sensitivities` methods.
380+
fn dense_allocate_return<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
381+
s: &S,
382+
t_eval: &[Eqn::T],
383+
) -> Result<<Eqn::V as DefaultDenseMatrix>::M, DiffsolError>
384+
where
385+
Eqn::V: DefaultDenseMatrix,
386+
{
387+
let nrows = if s.problem().eqn.out().is_some() {
388+
s.problem().eqn.out().unwrap().nout()
389+
} else {
390+
s.problem().eqn.rhs().nstates()
391+
};
392+
let ret = s
393+
.problem()
394+
.context()
395+
.dense_mat_zeros::<Eqn::V>(nrows, t_eval.len());
396+
397+
// check t_eval is increasing and all values are greater than or equal to the current time
398+
let t0 = s.state().t;
399+
if t_eval.windows(2).any(|w| w[0] > w[1] || w[0] < t0) {
400+
return Err(ode_solver_error!(InvalidTEval));
401+
}
402+
Ok(ret)
403+
}
404+
403405
#[cfg(test)]
404406
mod test {
405407
use crate::{
408+
error::{DiffsolError, OdeSolverError},
406409
matrix::dense_nalgebra_serial::NalgebraMat,
407410
ode_equations::test_models::exponential_decay::{
408411
exponential_decay_problem, exponential_decay_problem_adjoint,
@@ -477,6 +480,18 @@ mod test {
477480
}
478481
}
479482

483+
#[test]
484+
fn test_t_eval_errors() {
485+
let (problem, _soln) = exponential_decay_problem::<NalgebraMat<f64>>(false);
486+
let mut s = problem.bdf::<NalgebraLU<f64>>().unwrap();
487+
let t_eval = vec![0.0, 1.0, 0.5, 2.0];
488+
let err = s.solve_dense(t_eval.as_slice()).unwrap_err();
489+
assert!(matches!(
490+
err,
491+
DiffsolError::OdeSolverError(OdeSolverError::InvalidTEval)
492+
));
493+
}
494+
480495
#[test]
481496
fn test_dense_solve_sensitivities() {
482497
let (problem, soln) = exponential_decay_problem_sens::<NalgebraMat<f64>>(false);

diffsol/src/ode_solver/sensitivities.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,12 @@ where
8585
ret.column_mut(i).copy_from(tmp_nout);
8686
for (j, s_j) in s.iter_mut().enumerate() {
8787
// compute J * s_j + dF/dp * e_j where e_j is the jth basis vector
88+
let mut ret_sens = ret_sens[j].column_mut(i);
8889
tmp_nparams.set_index(j, Eqn::T::one());
8990
out.jac_mul_inplace(&y, *t, s_j, tmp_nout);
90-
s_j.copy_from(tmp_nout);
91+
ret_sens.copy_from(tmp_nout);
9192
out.sens_mul_inplace(&y, *t, tmp_nparams, tmp_nout);
92-
s_j.add_assign(&*tmp_nout);
93-
ret_sens[j].column_mut(i).copy_from(s_j);
93+
ret_sens.add_assign(&*tmp_nout);
9494
tmp_nparams.set_index(j, Eqn::T::zero());
9595
}
9696
} else {
@@ -107,23 +107,21 @@ where
107107
}
108108
let y = self.state().y;
109109
let s = self.state().s;
110-
let mut s_tmp = tmp_nout.clone();
111110
let i = t_eval.len() - 1;
112111
let t = t_eval.last().unwrap();
113112
if let Some(out) = self.problem().eqn.out() {
114113
let tmp_nout = tmp_nout.as_mut().unwrap();
115114
let tmp_nparams = tmp_nparms.as_mut().unwrap();
116-
let s_tmp = s_tmp.as_mut().unwrap();
117115
out.call_inplace(y, *t, tmp_nout);
118116
ret.column_mut(i).copy_from(tmp_nout);
119117
for (j, s_j) in s.iter().enumerate() {
120118
// compute J * s_j + dF/dp * e_j where e_j is the jth basis vector
119+
let mut ret_sens = ret_sens[j].column_mut(i);
121120
tmp_nparams.set_index(j, Eqn::T::one());
122121
out.jac_mul_inplace(y, *t, s_j, tmp_nout);
123-
s_tmp.copy_from(tmp_nout);
122+
ret_sens.copy_from(tmp_nout);
124123
out.sens_mul_inplace(y, *t, tmp_nparams, tmp_nout);
125-
s_tmp.add_assign(&*tmp_nout);
126-
ret_sens[j].column_mut(i).copy_from(s_tmp);
124+
ret_sens.add_assign(&*tmp_nout);
127125
tmp_nparams.set_index(j, Eqn::T::zero());
128126
}
129127
} else {

0 commit comments

Comments
 (0)