Skip to content

Commit 35b2bf9

Browse files
feat: add non-allocating interpolation functions for odesolvermethod (#190)
* feat: non-allocating interpolate funcs (bdf, rk) * feat: inplace interpolate fncs for sdirk * swap to a simple solution for growth * remove import * use interpolate_inplace * add tests for interpolate_out * improve test_interpolate * fix clippy * more interpolate tests
1 parent 3cd7f17 commit 35b2bf9

File tree

10 files changed

+502
-166
lines changed

10 files changed

+502
-166
lines changed

diffsol/benches/ode_solvers.rs

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ fn criterion_benchmark(c: &mut Criterion) {
1515
c.bench_function(stringify!($name), |b| {
1616
b.iter(|| {
1717
let (problem, soln) = $model_problem::<$matrix>(false);
18-
benchmarks::$solver::<_, $linear_solver<_>>(
19-
&problem,
20-
soln.solution_points.last().unwrap().t,
21-
);
18+
let t_evals = soln
19+
.solution_points
20+
.iter()
21+
.map(|sp| sp.t)
22+
.collect::<Vec<_>>();
23+
benchmarks::$solver::<_, $linear_solver<_>>(&problem, &t_evals);
2224
})
2325
});
2426
};
@@ -29,7 +31,12 @@ fn criterion_benchmark(c: &mut Criterion) {
2931
c.bench_function(stringify!($name), |b| {
3032
b.iter(|| {
3133
let (problem, soln) = $model_problem::<$matrix>(false);
32-
benchmarks::$solver::<_>(&problem, soln.solution_points.last().unwrap().t);
34+
let t_evals = soln
35+
.solution_points
36+
.iter()
37+
.map(|sp| sp.t)
38+
.collect::<Vec<_>>();
39+
benchmarks::$solver::<_>(&problem, &t_evals);
3340
})
3441
});
3542
};
@@ -144,7 +151,8 @@ fn criterion_benchmark(c: &mut Criterion) {
144151
$(c.bench_function(concat!(stringify!($name), "_", $N), |b| {
145152
b.iter(|| {
146153
let (problem, soln) = $model_problem::<$matrix>(false, $N);
147-
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t);
154+
let t_evals = soln.solution_points.iter().map(|sp| sp.t).collect::<Vec<_>>();
155+
benchmarks::$solver::<_, $linear_solver<_>>(&problem, &t_evals);
148156
})
149157
});)+
150158
};
@@ -196,12 +204,12 @@ fn criterion_benchmark(c: &mut Criterion) {
196204
use diffsol::ode_equations::test_models::robertson::*;
197205
use diffsol::LlvmModule;
198206
let (problem, soln) = robertson_diffsl_problem::<$matrix, LlvmModule>();
199-
b.iter(|| {
200-
benchmarks::$solver::<_, $linear_solver<_>>(
201-
&problem,
202-
soln.solution_points.last().unwrap().t,
203-
)
204-
})
207+
let t_evals = soln
208+
.solution_points
209+
.iter()
210+
.map(|sp| sp.t)
211+
.collect::<Vec<_>>();
212+
b.iter(|| benchmarks::$solver::<_, $linear_solver<_>>(&problem, &t_evals))
205213
});
206214
};
207215
}
@@ -218,7 +226,8 @@ fn criterion_benchmark(c: &mut Criterion) {
218226
$(c.bench_function(concat!(stringify!($name), "_", $N), |b| {
219227
b.iter(|| {
220228
let (problem, soln) = $model_problem::<$matrix, $N>();
221-
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t)
229+
let t_evals = soln.solution_points.iter().map(|sp| sp.t).collect::<Vec<_>>();
230+
benchmarks::$solver::<_, $linear_solver<_>>(&problem, &t_evals)
222231
})
223232
});)+
224233
};
@@ -268,7 +277,8 @@ fn criterion_benchmark(c: &mut Criterion) {
268277
$(c.bench_function(concat!(stringify!($name), "_", $N), |b| {
269278
b.iter(|| {
270279
let (problem, soln) = $model_problem::<$matrix, $N>();
271-
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t)
280+
let t_evals = soln.solution_points.iter().map(|sp| sp.t).collect::<Vec<_>>();
281+
benchmarks::$solver::<_, $linear_solver<_>>(&problem, &t_evals)
272282
})
273283
});)+
274284
};
@@ -320,8 +330,9 @@ fn criterion_benchmark(c: &mut Criterion) {
320330
use diffsol::ode_equations::test_models::heat2d::*;
321331
use diffsol::LlvmModule;
322332
let (problem, soln) = $model_problem::<$matrix, LlvmModule, $N>();
333+
let t_evals = soln.solution_points.iter().map(|sp| sp.t).collect::<Vec<_>>();
323334
b.iter(|| {
324-
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t)
335+
benchmarks::$solver::<_, $linear_solver<_>>(&problem, &t_evals)
325336
})
326337
});)+
327338
};
@@ -334,8 +345,9 @@ fn criterion_benchmark(c: &mut Criterion) {
334345
use diffsol::ode_equations::test_models::heat1d::*;
335346
use diffsol::LlvmModule;
336347
let (problem, soln) = $model_problem::<$matrix, LlvmModule, $N>();
348+
let t_evals = soln.solution_points.iter().map(|sp| sp.t).collect::<Vec<_>>();
337349
b.iter(|| {
338-
benchmarks::$solver::<_>(&problem, soln.solution_points.last().unwrap().t)
350+
benchmarks::$solver::<_>(&problem, &t_evals)
339351
})
340352
});)+
341353
};
@@ -364,6 +376,17 @@ fn criterion_benchmark(c: &mut Criterion) {
364376
80
365377
);
366378

379+
bench_diffsl_heat1d!(
380+
nalgebra_tsit45_diffsl_heat1d,
381+
tsit45,
382+
heat1d_diffsl_problem,
383+
NalgebraMat<f64>,
384+
10,
385+
20,
386+
40,
387+
80
388+
);
389+
367390
macro_rules! bench_sundials {
368391
($name:ident, $solver:ident) => {
369392
#[cfg(feature = "sundials")]
@@ -409,8 +432,9 @@ fn criterion_benchmark(c: &mut Criterion) {
409432
use diffsol::ode_equations::test_models::foodweb::*;
410433
use diffsol::LlvmModule;
411434
let (problem, soln) = foodweb_diffsl_problem::<$matrix, LlvmModule, $N>();
435+
let t_evals = soln.solution_points.iter().map(|sp| sp.t).collect::<Vec<_>>();
412436
b.iter(|| {
413-
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t)
437+
benchmarks::$solver::<_, $linear_solver<_>>(&problem, &t_evals)
414438
})
415439
});)+
416440

@@ -442,7 +466,7 @@ mod benchmarks {
442466
};
443467

444468
// bdf
445-
pub fn bdf<Eqn, LS>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T)
469+
pub fn bdf<Eqn, LS>(problem: &OdeSolverProblem<Eqn>, t_evals: &[Eqn::T])
446470
where
447471
Eqn: OdeEquationsImplicit,
448472
Eqn::M: Matrix + DefaultSolver,
@@ -452,10 +476,10 @@ mod benchmarks {
452476
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
453477
{
454478
let mut s = problem.bdf::<LS>().unwrap();
455-
let _y = s.solve(t);
479+
let _y = s.solve_dense(t_evals);
456480
}
457481

458-
pub fn esdirk34<Eqn, LS>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T)
482+
pub fn esdirk34<Eqn, LS>(problem: &OdeSolverProblem<Eqn>, t_evals: &[Eqn::T])
459483
where
460484
Eqn: OdeEquationsImplicit,
461485
Eqn::M: Matrix + DefaultSolver,
@@ -465,10 +489,10 @@ mod benchmarks {
465489
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
466490
{
467491
let mut s = problem.esdirk34::<LS>().unwrap();
468-
let _y = s.solve(t);
492+
let _y = s.solve_dense(t_evals);
469493
}
470494

471-
pub fn tr_bdf2<Eqn, LS>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T)
495+
pub fn tr_bdf2<Eqn, LS>(problem: &OdeSolverProblem<Eqn>, t_evals: &[Eqn::T])
472496
where
473497
Eqn: OdeEquationsImplicit,
474498
Eqn::M: Matrix + DefaultSolver,
@@ -478,10 +502,10 @@ mod benchmarks {
478502
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
479503
{
480504
let mut s = problem.tr_bdf2::<LS>().unwrap();
481-
let _y = s.solve(t);
505+
let _y = s.solve_dense(t_evals);
482506
}
483507

484-
pub fn tsit45<Eqn>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T)
508+
pub fn tsit45<Eqn>(problem: &OdeSolverProblem<Eqn>, t_evals: &[Eqn::T])
485509
where
486510
Eqn: OdeEquationsImplicit,
487511
Eqn::M: Matrix + DefaultSolver,
@@ -490,6 +514,6 @@ mod benchmarks {
490514
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
491515
{
492516
let mut s = problem.tsit45().unwrap();
493-
let _y = s.solve(t);
517+
let _y = s.solve_dense(t_evals);
494518
}
495519
}

diffsol/src/error.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ pub enum OdeSolverError {
7070
MassMatrixNotSupported,
7171
#[error("Stop time is at the current state time")]
7272
StopTimeAtCurrentTime,
73+
#[error("Interpolation vector is not the correct length, expected {expected}, got {found}")]
74+
InterpolationVectorWrongSize { expected: usize, found: usize },
75+
#[error("Number of sensitivities does not match number of parameters")]
76+
SensitivityCountMismatch { expected: usize, found: usize },
7377
#[error("Interpolation time is after current time")]
7478
InterpolationTimeAfterCurrentTime,
7579
#[error("Interpolation time is not within the current step. Step size is zero after calling state_mut()")]

diffsol/src/nonlinear_solver/root.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,17 @@ pub struct RootFinder<V: Vector> {
1414
g0: RefCell<V>,
1515
g1: RefCell<V>,
1616
gmid: RefCell<V>,
17+
ymid: RefCell<V>,
1718
}
1819

1920
impl<V: Vector> RootFinder<V> {
20-
pub fn new(n: usize, ctx: V::C) -> Self {
21+
pub fn new(nroots: usize, nstates: usize, ctx: V::C) -> Self {
2122
Self {
2223
t0: RefCell::new(V::T::zero()),
23-
g0: RefCell::new(V::zeros(n, ctx.clone())),
24-
g1: RefCell::new(V::zeros(n, ctx.clone())),
25-
gmid: RefCell::new(V::zeros(n, ctx)),
24+
g0: RefCell::new(V::zeros(nroots, ctx.clone())),
25+
g1: RefCell::new(V::zeros(nroots, ctx.clone())),
26+
gmid: RefCell::new(V::zeros(nroots, ctx.clone())),
27+
ymid: RefCell::new(V::zeros(nstates, ctx)),
2628
}
2729
}
2830

@@ -42,14 +44,15 @@ impl<V: Vector> RootFinder<V> {
4244
/// We find the root of a function using the method proposed by Sundials [docs](https://sundials.readthedocs.io/en/latest/cvode/Mathematics_link.html#rootfinding)
4345
pub fn check_root(
4446
&self,
45-
interpolate: &impl Fn(V::T) -> Result<V, DiffsolError>,
47+
interpolate_inplace: &impl Fn(V::T, &mut V) -> Result<(), DiffsolError>,
4648
root_fn: &impl NonLinearOp<V = V, T = V::T>,
4749
y: &V,
4850
t: V::T,
4951
) -> Option<V::T> {
5052
let g1 = &mut *self.g1.borrow_mut();
5153
let g0 = &mut *self.g0.borrow_mut();
5254
let gmid = &mut *self.gmid.borrow_mut();
55+
let ymid = &mut *self.ymid.borrow_mut();
5356
root_fn.call_inplace(y, t, g1);
5457

5558
let (rootfnd, _gfracmax, imax) = g0.root_finding(g1);
@@ -105,8 +108,8 @@ impl<V: Vector> RootFinder<V> {
105108
t_mid = t1 - fracsub * (t1 - t0);
106109
}
107110

108-
let ymid = interpolate(t_mid).unwrap();
109-
root_fn.call_inplace(&ymid, t_mid, gmid);
111+
interpolate_inplace(t_mid, ymid).unwrap();
112+
root_fn.call_inplace(ymid, t_mid, gmid);
110113

111114
let (rootfnd, _gfracmax, imax_i32) = g0.root_finding(gmid);
112115
let lower = imax_i32 >= 0;
@@ -158,8 +161,10 @@ mod tests {
158161
type V = NalgebraVec<f64>;
159162
type M = NalgebraMat<f64>;
160163
let ctx = NalgebraContext;
161-
let interpolate =
162-
|t: f64| -> Result<V, DiffsolError> { Ok(Vector::from_vec(vec![t], ctx.clone())) };
164+
let interpolate_inplace = |t: f64, y: &mut V| -> Result<(), DiffsolError> {
165+
y[0] = t;
166+
Ok(())
167+
};
163168
let p = V::zeros(0, ctx.clone());
164169
let root_fn = ClosureNoJac::<M, _>::new(
165170
|y: &V, _p: &V, _t: f64, g: &mut V| {
@@ -173,21 +178,21 @@ mod tests {
173178
let root_fn = ParameterisedOp::new(&root_fn, &p);
174179

175180
// check no root
176-
let root_finder = RootFinder::new(1, ctx.clone());
181+
let root_finder = RootFinder::new(1, 1, ctx.clone());
177182
root_finder.init(&root_fn, &Vector::from_vec(vec![0.0], ctx.clone()), 0.0);
178183
let root = root_finder.check_root(
179-
&interpolate,
184+
&interpolate_inplace,
180185
&root_fn,
181186
&Vector::from_vec(vec![0.3], ctx.clone()),
182187
0.3,
183188
);
184189
assert_eq!(root, None);
185190

186191
// check root
187-
let root_finder = RootFinder::new(1, ctx.clone());
192+
let root_finder = RootFinder::new(1, 1, ctx.clone());
188193
root_finder.init(&root_fn, &Vector::from_vec(vec![0.0], ctx.clone()), 0.0);
189194
let root = root_finder.check_root(
190-
&interpolate,
195+
&interpolate_inplace,
191196
&root_fn,
192197
&Vector::from_vec(vec![1.3], ctx.clone()),
193198
1.3,

0 commit comments

Comments
 (0)