Skip to content

Commit 72535f8

Browse files
feat: sensitivity dense out now supports output ops (#184)
* update ode equations trait and builder * finish loop, need to fix the last state * works * remove odeequationssens as unnecessary
1 parent 3d3cdc1 commit 72535f8

File tree

7 files changed

+225
-31
lines changed

7 files changed

+225
-31
lines changed

diffsol/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,8 @@ pub use ode_equations::{
189189
adjoint_equations::AdjointInit, adjoint_equations::AdjointRhs, sens_equations::SensEquations,
190190
sens_equations::SensInit, sens_equations::SensRhs, AugmentedOdeEquations,
191191
AugmentedOdeEquationsImplicit, NoAug, OdeEquations, OdeEquationsAdjoint, OdeEquationsImplicit,
192-
OdeEquationsImplicitAdjoint, OdeEquationsImplicitSens, OdeEquationsRef, OdeEquationsSens,
193-
OdeEquationsStoch, OdeSolverEquations,
192+
OdeEquationsImplicitAdjoint, OdeEquationsImplicitSens, OdeEquationsRef, OdeEquationsStoch,
193+
OdeSolverEquations,
194194
};
195195
use ode_solver::jacobian_update::JacobianUpdate;
196196
pub use ode_solver::sde::SdeSolverMethod;

diffsol/src/ode_equations/mod.rs

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -347,25 +347,11 @@ impl<T> OdeEquationsStoch for T where
347347
{
348348
}
349349

350-
pub trait OdeEquationsSens:
351-
OdeEquations<
352-
Rhs: NonLinearOpSens<M = Self::M, V = Self::V, T = Self::T, C = Self::C>,
353-
Init: ConstantOpSens<M = Self::M, V = Self::V, T = Self::T, C = Self::C>,
354-
>
355-
{
356-
}
357-
358-
impl<T> OdeEquationsSens for T where
359-
T: OdeEquations<
360-
Rhs: NonLinearOpSens<M = T::M, V = T::V, T = T::T, C = T::C>,
361-
Init: ConstantOpSens<M = T::M, V = T::V, T = T::T, C = T::C>,
362-
>
363-
{
364-
}
365-
366350
pub trait OdeEquationsImplicitSens:
367351
OdeEquationsImplicit<
368352
Rhs: NonLinearOpSens<M = Self::M, V = Self::V, T = Self::T, C = Self::C>,
353+
Out: NonLinearOpSens<M = Self::M, V = Self::V, T = Self::T, C = Self::C>
354+
+ NonLinearOpJacobian<M = Self::M, V = Self::V, T = Self::T, C = Self::C>,
369355
Init: ConstantOpSens<M = Self::M, V = Self::V, T = Self::T, C = Self::C>,
370356
>
371357
{
@@ -374,6 +360,8 @@ pub trait OdeEquationsImplicitSens:
374360
impl<T> OdeEquationsImplicitSens for T where
375361
T: OdeEquationsImplicit<
376362
Rhs: NonLinearOpSens<M = T::M, V = T::V, T = T::T, C = T::C>,
363+
Out: NonLinearOpSens<M = T::M, V = T::V, T = T::T, C = T::C>
364+
+ NonLinearOpJacobian<M = T::M, V = T::V, T = T::T, C = T::C>,
377365
Init: ConstantOpSens<M = T::M, V = T::V, T = T::T, C = T::C>,
378366
>
379367
{

diffsol/src/ode_equations/sens_equations.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::cell::RefCell;
44
use crate::{
55
op::nonlinear_op::NonLinearOpJacobian, AugmentedOdeEquations, ConstantOp, ConstantOpSens,
66
Matrix, NonLinearOp, NonLinearOpSens, OdeEquations, OdeEquationsImplicitSens, OdeEquationsRef,
7-
OdeEquationsSens, OdeSolverProblem, Op, Vector,
7+
OdeSolverProblem, Op, Vector,
88
};
99

1010
pub struct SensInit<'a, Eqn>
@@ -64,7 +64,7 @@ where
6464

6565
impl<Eqn> ConstantOp for SensInit<'_, Eqn>
6666
where
67-
Eqn: OdeEquationsSens,
67+
Eqn: OdeEquationsImplicitSens,
6868
{
6969
fn call_inplace(&self, _t: Self::T, y: &mut Self::V) {
7070
self.eqn.init().sens_mul_inplace(self.t0, &self.tmp, y);

diffsol/src/ode_equations/test_models/exponential_decay.rs

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,18 @@ fn exponential_decay_out_adj_mul<M: MatrixHost>(
123123
y[1] = -M::T::from(2.0) * v[0] - M::T::from(4.0) * v[1];
124124
}
125125

126+
/// J = |0 0|
127+
/// |0 0|
128+
fn exponential_decay_out_sens<M: MatrixHost>(
129+
_x: &M::V,
130+
_p: &M::V,
131+
_t: M::T,
132+
_v: &M::V,
133+
y: &mut M::V,
134+
) {
135+
y.fill(M::T::zero());
136+
}
137+
126138
/// J = |0 0|
127139
/// |0 0|
128140
fn exponential_decay_out_sens_adj<M: MatrixHost>(
@@ -422,6 +434,78 @@ pub fn exponential_decay_problem_sens<M: MatrixHost + 'static>(
422434
(problem, soln)
423435
}
424436

437+
#[allow(clippy::type_complexity)]
438+
pub fn exponential_decay_problem_sens_with_out<M: MatrixHost + 'static>(
439+
use_coloring: bool,
440+
) -> (
441+
OdeSolverProblem<impl OdeEquationsImplicitSens<M = M, V = M::V, T = M::T, C = M::C>>,
442+
OdeSolverSolution<M::V>,
443+
) {
444+
let k = 0.1;
445+
let y0 = 1.0;
446+
let problem = OdeBuilder::<M>::new()
447+
.p([k, y0])
448+
.sens_rtol(1e-6)
449+
.sens_atol([1e-6, 1e-6])
450+
.use_coloring(use_coloring)
451+
.rhs_sens_implicit(
452+
exponential_decay::<M>,
453+
exponential_decay_jacobian::<M>,
454+
exponential_decay_sens::<M>,
455+
)
456+
.init_sens(
457+
exponential_decay_init::<M>,
458+
exponential_decay_init_sens::<M>,
459+
2,
460+
)
461+
// g_1 = 1 * x_1 + 2 * x_2
462+
// g_2 = 3 * x_1 + 4 * x_2
463+
.out_sens_implicit(
464+
exponential_decay_out::<M>,
465+
exponential_decay_out_jac_mul::<M>,
466+
exponential_decay_out_sens::<M>,
467+
2,
468+
)
469+
.build()
470+
.unwrap();
471+
let p = [M::T::from(k), M::T::from(y0)];
472+
let mut soln = OdeSolverSolution::default();
473+
474+
for i in 0..10 {
475+
let t = M::T::from(i as f64);
476+
let y0: M::V = problem.eqn.init().call(M::T::zero());
477+
let y = y0.clone() * scale(M::T::exp(-p[0] * t));
478+
let y_out = M::V::from_vec(
479+
vec![
480+
M::T::from(1.0) * y[0] + M::T::from(2.0) * y[1],
481+
M::T::from(3.0) * y[0] + M::T::from(4.0) * y[1],
482+
],
483+
y.context().clone(),
484+
);
485+
let ypk = y0 * scale(-t * M::T::exp(-p[0] * t));
486+
let ypk_out = M::V::from_vec(
487+
vec![
488+
M::T::from(1.0) * ypk[0] + M::T::from(2.0) * ypk[1],
489+
M::T::from(3.0) * ypk[0] + M::T::from(4.0) * ypk[1],
490+
],
491+
y.context().clone(),
492+
);
493+
let ypy0 = M::V::from_vec(
494+
vec![(-p[0] * t).exp(), (-p[0] * t).exp()],
495+
y.context().clone(),
496+
);
497+
let ypy0_out = M::V::from_vec(
498+
vec![
499+
M::T::from(1.0) * ypy0[0] + M::T::from(2.0) * ypy0[1],
500+
M::T::from(3.0) * ypy0[0] + M::T::from(4.0) * ypy0[1],
501+
],
502+
y.context().clone(),
503+
);
504+
soln.push_sens(y_out, t, &[ypk_out, ypy0_out]);
505+
}
506+
(problem, soln)
507+
}
508+
425509
#[cfg(test)]
426510
mod tests {
427511
#[cfg(feature = "diffsl-llvm")]

diffsol/src/ode_solver/builder.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,50 @@ where
607607
}
608608
}
609609

610+
pub fn out_sens_implicit<F, G, H>(
611+
self,
612+
out: F,
613+
out_jac: G,
614+
out_sens: H,
615+
nout: usize,
616+
) -> OdeBuilder<M, Rhs, Init, Mass, Root, ClosureWithSens<M, F, G, H>>
617+
where
618+
F: Fn(&M::V, &M::V, M::T, &mut M::V),
619+
G: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
620+
H: Fn(&M::V, &M::V, M::T, &M::V, &mut M::V),
621+
{
622+
let nstates = 0;
623+
OdeBuilder::<M, Rhs, Init, Mass, Root, ClosureWithSens<M, F, G, H>> {
624+
rhs: self.rhs,
625+
init: self.init,
626+
mass: self.mass,
627+
root: self.root,
628+
out: Some(ClosureWithSens::new(
629+
out,
630+
out_jac,
631+
out_sens,
632+
nstates,
633+
nstates,
634+
nout,
635+
self.ctx.clone(),
636+
)),
637+
t0: self.t0,
638+
h0: self.h0,
639+
rtol: self.rtol,
640+
atol: self.atol,
641+
sens_atol: self.sens_atol,
642+
sens_rtol: self.sens_rtol,
643+
out_rtol: self.out_rtol,
644+
out_atol: self.out_atol,
645+
param_rtol: self.param_rtol,
646+
param_atol: self.param_atol,
647+
p: self.p,
648+
use_coloring: self.use_coloring,
649+
integrate_out: self.integrate_out,
650+
ctx: self.ctx,
651+
}
652+
}
653+
610654
#[allow(clippy::type_complexity)]
611655
pub fn out_adjoint_implicit<F, G, H, I>(
612656
self,

diffsol/src/ode_solver/method.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ mod test {
406406
matrix::dense_nalgebra_serial::NalgebraMat,
407407
ode_equations::test_models::exponential_decay::{
408408
exponential_decay_problem, exponential_decay_problem_adjoint,
409-
exponential_decay_problem_sens,
409+
exponential_decay_problem_sens, exponential_decay_problem_sens_with_out,
410410
},
411411
scale, AdjointOdeSolverMethod, DenseMatrix, NalgebraLU, NalgebraVec, OdeEquations,
412412
OdeSolverMethod, Op, SensitivitiesOdeSolverMethod, Vector, VectorView,
@@ -501,6 +501,30 @@ mod test {
501501
}
502502
}
503503

504+
#[test]
505+
fn test_dense_solve_sensitivities_with_out() {
506+
let (problem, soln) = exponential_decay_problem_sens_with_out::<NalgebraMat<f64>>(false);
507+
let mut s = problem.bdf_sens::<NalgebraLU<f64>>().unwrap();
508+
509+
let t_eval = soln.solution_points.iter().map(|p| p.t).collect::<Vec<_>>();
510+
let (y, sens) = s.solve_dense_sensitivities(t_eval.as_slice()).unwrap();
511+
for (i, soln_pt) in soln.solution_points.iter().enumerate() {
512+
let y_i = y.column(i).into_owned();
513+
y_i.assert_eq_norm(&soln_pt.state, &problem.atol, problem.rtol, 15.0);
514+
}
515+
for (j, soln_pts) in soln.sens_solution_points.unwrap().iter().enumerate() {
516+
for (i, soln_pt) in soln_pts.iter().enumerate() {
517+
let sens_i = sens[j].column(i).into_owned();
518+
sens_i.assert_eq_norm(
519+
&soln_pt.state,
520+
problem.sens_atol.as_ref().unwrap(),
521+
problem.sens_rtol.unwrap(),
522+
15.0,
523+
);
524+
}
525+
}
526+
}
527+
504528
#[test]
505529
fn test_solve_adjoint() {
506530
let (problem, soln) = exponential_decay_problem_adjoint::<NalgebraMat<f64>>(true);

diffsol/src/ode_solver/sensitivities.rs

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
use crate::{
22
error::DiffsolError, error::OdeSolverError, ode_solver_error, AugmentedOdeSolverMethod,
3-
Context, DefaultDenseMatrix, DefaultSolver, DenseMatrix, OdeEquationsImplicitSens,
4-
OdeSolverStopReason, Op, SensEquations, VectorViewMut,
3+
Context, DefaultDenseMatrix, DefaultSolver, DenseMatrix, NonLinearOp, NonLinearOpJacobian,
4+
NonLinearOpSens, OdeEquationsImplicitSens, OdeSolverStopReason, Op, SensEquations, Vector,
5+
VectorViewMut,
56
};
7+
use num_traits::{One, Zero};
8+
use std::ops::AddAssign;
69

710
pub trait SensitivitiesOdeSolverMethod<'a, Eqn>:
811
AugmentedOdeSolverMethod<'a, Eqn, SensEquations<'a, Eqn>>
@@ -36,7 +39,19 @@ where
3639
"Cannot integrate out when solving for sensitivities"
3740
));
3841
}
39-
let nrows = self.problem().eqn.rhs().nstates();
42+
let (mut tmp_nout, mut tmp_nparms, nrows) = if let Some(out) = self.problem().eqn.out() {
43+
(
44+
Some(Eqn::V::zeros(out.nout(), self.problem().context().clone())),
45+
Some(Eqn::V::zeros(
46+
out.nparams(),
47+
self.problem().context().clone(),
48+
)),
49+
out.nout(),
50+
)
51+
} else {
52+
(None, None, self.problem().eqn.rhs().nout())
53+
};
54+
4055
let mut ret = self
4156
.problem()
4257
.context()
@@ -62,10 +77,27 @@ where
6277
step_reason = self.step()?;
6378
}
6479
let y = self.interpolate(*t)?;
65-
ret.column_mut(i).copy_from(&y);
66-
let s = self.interpolate_sens(*t)?;
67-
for (j, s_j) in s.iter().enumerate() {
68-
ret_sens[j].column_mut(i).copy_from(s_j);
80+
let mut s = self.interpolate_sens(*t)?;
81+
if let Some(out) = self.problem().eqn.out() {
82+
let tmp_nout = tmp_nout.as_mut().unwrap();
83+
let tmp_nparams = tmp_nparms.as_mut().unwrap();
84+
out.call_inplace(&y, *t, tmp_nout);
85+
ret.column_mut(i).copy_from(tmp_nout);
86+
for (j, s_j) in s.iter_mut().enumerate() {
87+
// compute J * s_j + dF/dp * e_j where e_j is the jth basis vector
88+
tmp_nparams.set_index(j, Eqn::T::one());
89+
out.jac_mul_inplace(&y, *t, s_j, tmp_nout);
90+
s_j.copy_from(tmp_nout);
91+
out.sens_mul_inplace(&y, *t, tmp_nparams, tmp_nout);
92+
s_j.add_assign(&*tmp_nout);
93+
ret_sens[j].column_mut(i).copy_from(s_j);
94+
tmp_nparams.set_index(j, Eqn::T::zero());
95+
}
96+
} else {
97+
ret.column_mut(i).copy_from(&y);
98+
for (j, s_j) in s.iter().enumerate() {
99+
ret_sens[j].column_mut(i).copy_from(s_j);
100+
}
69101
}
70102
}
71103

@@ -74,11 +106,33 @@ where
74106
step_reason = self.step()?;
75107
}
76108
let y = self.state().y;
77-
ret.column_mut(t_eval.len() - 1).copy_from(y);
78109
let s = self.state().s;
79-
for (j, s_j) in s.iter().enumerate() {
80-
ret_sens[j].column_mut(t_eval.len() - 1).copy_from(s_j);
110+
let mut s_tmp = tmp_nout.clone();
111+
let i = t_eval.len() - 1;
112+
let t = t_eval.last().unwrap();
113+
if let Some(out) = self.problem().eqn.out() {
114+
let tmp_nout = tmp_nout.as_mut().unwrap();
115+
let tmp_nparams = tmp_nparms.as_mut().unwrap();
116+
let s_tmp = s_tmp.as_mut().unwrap();
117+
out.call_inplace(y, *t, tmp_nout);
118+
ret.column_mut(i).copy_from(tmp_nout);
119+
for (j, s_j) in s.iter().enumerate() {
120+
// compute J * s_j + dF/dp * e_j where e_j is the jth basis vector
121+
tmp_nparams.set_index(j, Eqn::T::one());
122+
out.jac_mul_inplace(y, *t, s_j, tmp_nout);
123+
s_tmp.copy_from(tmp_nout);
124+
out.sens_mul_inplace(y, *t, tmp_nparams, tmp_nout);
125+
s_tmp.add_assign(&*tmp_nout);
126+
ret_sens[j].column_mut(i).copy_from(s_tmp);
127+
tmp_nparams.set_index(j, Eqn::T::zero());
128+
}
129+
} else {
130+
ret.column_mut(i).copy_from(y);
131+
for (j, s_j) in s.iter().enumerate() {
132+
ret_sens[j].column_mut(i).copy_from(s_j);
133+
}
81134
}
135+
82136
Ok((ret, ret_sens))
83137
}
84138
}

0 commit comments

Comments
 (0)