Skip to content

Commit 3cd7f17

Browse files
feat: non-allocating write outs (#189)
* feat: non-allocating dense write out * format * feat: non-allocating solve write_out test * add some tests and tidy * use named constants * format
1 parent 73bd6b7 commit 3cd7f17

File tree

4 files changed

+137
-40
lines changed

4 files changed

+137
-40
lines changed

diffsol/src/matrix/dense_faer_serial.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,14 @@ impl<T: Scalar> DenseMatrix for FaerMat<T> {
177177
Self { data, context: ctx }
178178
}
179179

180+
fn resize_cols(&mut self, ncols: IndexType) {
181+
if ncols == self.ncols() {
182+
return;
183+
}
184+
let nrows = self.nrows();
185+
self.data.resize_with(nrows, ncols, |_, _| T::zero());
186+
}
187+
180188
fn get_index(&self, i: IndexType, j: IndexType) -> Self::T {
181189
self.data[(i, j)]
182190
}
@@ -384,4 +392,9 @@ mod tests {
384392
fn test_partition_indices_by_zero_diagonal() {
385393
super::super::tests::test_partition_indices_by_zero_diagonal::<FaerMat<f64>>();
386394
}
395+
396+
#[test]
397+
fn test_resize_cols() {
398+
super::super::tests::test_resize_cols::<FaerMat<f64>>();
399+
}
387400
}

diffsol/src/matrix/dense_nalgebra_serial.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,13 @@ impl<T: Scalar> DenseMatrix for NalgebraMat<T> {
270270
self.data.gemm(alpha, &a.data, &b.data, beta);
271271
}
272272

273+
fn resize_cols(&mut self, ncols: IndexType) {
274+
if ncols == self.ncols() {
275+
return;
276+
}
277+
self.data.resize_horizontally_mut(ncols, Self::T::zero());
278+
}
279+
273280
fn get_index(&self, i: IndexType, j: IndexType) -> Self::T {
274281
self.data[(i, j)]
275282
}
@@ -347,4 +354,9 @@ mod tests {
347354
fn test_partition_indices_by_zero_diagonal() {
348355
super::super::tests::test_partition_indices_by_zero_diagonal::<NalgebraMat<f64>>();
349356
}
357+
358+
#[test]
359+
fn test_resize_cols() {
360+
super::super::tests::test_resize_cols::<NalgebraMat<f64>>();
361+
}
350362
}

diffsol/src/matrix/mod.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,9 @@ pub trait DenseMatrix:
326326
ret
327327
}
328328

329+
/// Resize the number of columns in the matrix. Data in the retained columns is preserved, new elements are initialized to zero
330+
fn resize_cols(&mut self, ncols: IndexType);
331+
329332
/// creates a new matrix from a vector of values, which are assumed
330333
/// to be in column-major order
331334
fn from_vec(nrows: IndexType, ncols: IndexType, data: Vec<Self::T>, ctx: Self::C) -> Self;
@@ -390,4 +393,33 @@ mod tests {
390393
assert_eq!(a.get_index(1, 0), M::T::from(3.0));
391394
assert_eq!(a.get_index(1, 1), M::T::from(10.0));
392395
}
396+
397+
pub fn test_resize_cols<M: DenseMatrix>() {
398+
let mut a = M::zeros(2, 2, Default::default());
399+
a.set_index(0, 0, M::T::from(1.0));
400+
a.set_index(0, 1, M::T::from(2.0));
401+
a.set_index(1, 0, M::T::from(3.0));
402+
a.set_index(1, 1, M::T::from(4.0));
403+
404+
a.resize_cols(3);
405+
assert_eq!(a.ncols(), 3);
406+
assert_eq!(a.nrows(), 2);
407+
assert_eq!(a.get_index(0, 0), M::T::from(1.0));
408+
assert_eq!(a.get_index(0, 1), M::T::from(2.0));
409+
assert_eq!(a.get_index(1, 0), M::T::from(3.0));
410+
assert_eq!(a.get_index(1, 1), M::T::from(4.0));
411+
412+
a.set_index(0, 2, M::T::from(5.0));
413+
a.set_index(1, 2, M::T::from(6.0));
414+
assert_eq!(a.get_index(0, 2), M::T::from(5.0));
415+
assert_eq!(a.get_index(1, 2), M::T::from(6.0));
416+
417+
a.resize_cols(2);
418+
assert_eq!(a.ncols(), 2);
419+
assert_eq!(a.nrows(), 2);
420+
assert_eq!(a.get_index(0, 0), M::T::from(1.0));
421+
assert_eq!(a.get_index(0, 1), M::T::from(2.0));
422+
assert_eq!(a.get_index(1, 0), M::T::from(3.0));
423+
assert_eq!(a.get_index(1, 1), M::T::from(4.0));
424+
}
393425
}

diffsol/src/ode_solver/method.rs

Lines changed: 80 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@ use crate::{
55
ode_solver_error,
66
scalar::Scalar,
77
AugmentedOdeEquations, Checkpointing, Context, DefaultDenseMatrix, DenseMatrix,
8-
HermiteInterpolator, NonLinearOp, OdeEquations, OdeSolverConfig, OdeSolverProblem,
9-
OdeSolverState, Op, StateRef, StateRefMut, Vector, VectorViewMut,
8+
HermiteInterpolator, MatrixCommon, NonLinearOp, OdeEquations, OdeSolverConfig,
9+
OdeSolverProblem, OdeSolverState, Op, StateRef, StateRefMut, Vector, VectorViewMut,
1010
};
11+
use nalgebra::ComplexField;
1112

1213
#[derive(Debug, PartialEq)]
1314
pub enum OdeSolverStopReason<T: Scalar> {
@@ -120,27 +121,20 @@ where
120121
Self: Sized,
121122
{
122123
let mut ret_t = Vec::new();
123-
let mut ret_y = Vec::new();
124+
let (mut ret_y, mut tmp_nout) = allocate_return(self)?;
124125

125126
// do the main loop
126-
write_out(self, &mut ret_y, &mut ret_t);
127+
write_out(self, &mut ret_y, &mut ret_t, final_time, &mut tmp_nout);
127128
self.set_stop_time(final_time)?;
128129
while self.step()? != OdeSolverStopReason::TstopReached {
129-
write_out(self, &mut ret_y, &mut ret_t);
130+
write_out(self, &mut ret_y, &mut ret_t, final_time, &mut tmp_nout);
130131
}
131132

132133
// store the final step
133-
write_out(self, &mut ret_y, &mut ret_t);
134+
write_out(self, &mut ret_y, &mut ret_t, final_time, &mut tmp_nout);
134135
let ntimes = ret_t.len();
135-
let nrows = ret_y[0].len();
136-
let mut ret_y_matrix = self
137-
.problem()
138-
.context()
139-
.dense_mat_zeros::<Eqn::V>(nrows, ntimes);
140-
for (i, y) in ret_y.iter().enumerate() {
141-
ret_y_matrix.column_mut(i).copy_from(y);
142-
}
143-
Ok((ret_y_matrix, ret_t))
136+
ret_y.resize_cols(ntimes);
137+
Ok((ret_y, ret_t))
144138
}
145139

146140
/// Using the provided state, solve the problem up to time `t_eval[t_eval.len()-1]`
@@ -154,7 +148,7 @@ where
154148
Eqn::V: DefaultDenseMatrix,
155149
Self: Sized,
156150
{
157-
let mut ret = dense_allocate_return(self, t_eval)?;
151+
let (mut ret, mut tmp_nout) = dense_allocate_return(self, t_eval)?;
158152

159153
// do loop
160154
self.set_stop_time(t_eval[t_eval.len() - 1])?;
@@ -163,7 +157,7 @@ where
163157
while self.state().t < *t {
164158
step_reason = self.step()?;
165159
}
166-
dense_write_out(self, &mut ret, t_eval, i)?;
160+
dense_write_out(self, &mut ret, t_eval, i, &mut tmp_nout)?;
167161
}
168162
assert_eq!(step_reason, OdeSolverStopReason::TstopReached);
169163
Ok(ret)
@@ -187,7 +181,7 @@ where
187181
Self: Sized,
188182
{
189183
let mut ret_t = Vec::new();
190-
let mut ret_y = Vec::new();
184+
let (mut ret_y, mut tmp_nout) = allocate_return(self)?;
191185
let max_steps_between_checkpoints = max_steps_between_checkpoints.unwrap_or(500);
192186

193187
// allocate checkpoint info
@@ -199,10 +193,10 @@ where
199193
let mut ydots = vec![self.state().dy.clone()];
200194

201195
// do the main loop, saving checkpoints
202-
write_out(self, &mut ret_y, &mut ret_t);
196+
write_out(self, &mut ret_y, &mut ret_t, final_time, &mut tmp_nout);
203197
self.set_stop_time(final_time)?;
204198
while self.step()? != OdeSolverStopReason::TstopReached {
205-
write_out(self, &mut ret_y, &mut ret_t);
199+
write_out(self, &mut ret_y, &mut ret_t, final_time, &mut tmp_nout);
206200
ts.push(self.state().t);
207201
ys.push(self.state().y.clone());
208202
ydots.push(self.state().dy.clone());
@@ -217,16 +211,9 @@ where
217211
}
218212

219213
// store the final step
220-
write_out(self, &mut ret_y, &mut ret_t);
214+
write_out(self, &mut ret_y, &mut ret_t, final_time, &mut tmp_nout);
221215
let ntimes = ret_t.len();
222-
let nrows = ret_y[0].len();
223-
let mut ret_y_matrix = self
224-
.problem()
225-
.context()
226-
.dense_mat_zeros::<Eqn::V>(nrows, ntimes);
227-
for (i, y) in ret_y.iter().enumerate() {
228-
ret_y_matrix.column_mut(i).copy_from(y);
229-
}
216+
ret_y.resize_cols(ntimes);
230217

231218
// add final checkpoint
232219
ts.push(self.state().t);
@@ -243,7 +230,7 @@ where
243230
Some(last_segment),
244231
);
245232

246-
Ok((checkpointer, ret_y_matrix, ret_t))
233+
Ok((checkpointer, ret_y, ret_t))
247234
}
248235

249236
/// Solve the problem and write out the solution at the given timepoints, using checkpointing so that
@@ -265,7 +252,7 @@ where
265252
Eqn::V: DefaultDenseMatrix,
266253
Self: Sized,
267254
{
268-
let mut ret = dense_allocate_return(self, t_eval)?;
255+
let (mut ret, mut tmp_nout) = dense_allocate_return(self, t_eval)?;
269256
let max_steps_between_checkpoints = max_steps_between_checkpoints.unwrap_or(500);
270257

271258
// allocate checkpoint info
@@ -296,7 +283,7 @@ where
296283
ydots.clear();
297284
}
298285
}
299-
dense_write_out(self, &mut ret, t_eval, i)?;
286+
dense_write_out(self, &mut ret, t_eval, i, &mut tmp_nout)?;
300287
}
301288
assert_eq!(step_reason, OdeSolverStopReason::TstopReached);
302289

@@ -334,6 +321,7 @@ fn dense_write_out<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
334321
y_out: &mut <Eqn::V as DefaultDenseMatrix>::M,
335322
t_eval: &[Eqn::T],
336323
i: usize,
324+
tmp_nout: &mut Eqn::V,
337325
) -> Result<(), DiffsolError>
338326
where
339327
Eqn::V: DefaultDenseMatrix,
@@ -346,7 +334,10 @@ where
346334
} else {
347335
let y = s.interpolate(t)?;
348336
match s.problem().eqn.out() {
349-
Some(out) => y_out.copy_from(&out.call(&y, t_eval[i])),
337+
Some(out) => {
338+
out.call_inplace(&y, t_eval[i], tmp_nout);
339+
y_out.copy_from(tmp_nout)
340+
}
350341
None => y_out.copy_from(&y),
351342
}
352343
}
@@ -357,30 +348,74 @@ where
357348
/// This function is used by the `solve` method to write out the solution at a given timepoint.
358349
fn write_out<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
359350
s: &S,
360-
ret_y: &mut Vec<Eqn::V>,
351+
ret_y: &mut <Eqn::V as DefaultDenseMatrix>::M,
361352
ret_t: &mut Vec<Eqn::T>,
362-
) {
353+
final_time: Eqn::T,
354+
tmp_nout: &mut Eqn::V,
355+
) where
356+
Eqn::V: DefaultDenseMatrix,
357+
{
363358
let t = s.state().t;
364359
let y = s.state().y;
365360
ret_t.push(t);
361+
let i = ret_t.len() - 1;
362+
if i >= ret_y.ncols() {
363+
const GROWTH_FACTOR: f64 = 1.5;
364+
let remaining: f64 = (Eqn::T::from(GROWTH_FACTOR) * (final_time - ret_t[i - 1])
365+
/ (ret_t[i] - ret_t[i - 1]))
366+
.ceil()
367+
.into();
368+
let n = ret_y.ncols() + (remaining as usize);
369+
ret_y.resize_cols(n);
370+
}
371+
let mut ret_y_col = ret_y.column_mut(i);
366372
match s.problem().eqn.out() {
367373
Some(out) => {
368374
if s.problem().integrate_out {
369-
ret_y.push(s.state().g.clone());
375+
ret_y_col.copy_from(s.state().g);
370376
} else {
371-
ret_y.push(out.call(y, t));
377+
out.call_inplace(y, t, tmp_nout);
378+
ret_y_col.copy_from(tmp_nout);
372379
}
373380
}
374-
None => ret_y.push(y.clone()),
381+
None => ret_y_col.copy_from(y),
375382
}
376383
}
377384

385+
/// Utility function to allocate the return matrix for the `solve`
386+
/// method
387+
fn allocate_return<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
388+
s: &S,
389+
) -> Result<(<Eqn::V as DefaultDenseMatrix>::M, Eqn::V), DiffsolError>
390+
where
391+
Eqn::V: DefaultDenseMatrix,
392+
{
393+
let nrows = if s.problem().eqn.out().is_some() {
394+
s.problem().eqn.out().unwrap().nout()
395+
} else {
396+
s.problem().eqn.rhs().nstates()
397+
};
398+
const INITIAL_NCOLS: usize = 10;
399+
let ret = s
400+
.problem()
401+
.context()
402+
.dense_mat_zeros::<Eqn::V>(nrows, INITIAL_NCOLS);
403+
404+
// allocate a temporary vector to hold the result of the output function (zero-length if the equations define no output)
405+
let tmp_nout = if let Some(out) = s.problem().eqn.out() {
406+
Eqn::V::zeros(out.nout(), s.problem().context().clone())
407+
} else {
408+
Eqn::V::zeros(0, s.problem().context().clone())
409+
};
410+
Ok((ret, tmp_nout))
411+
}
412+
378413
/// Utility function to allocate the return matrix for the `solve_dense`
379414
/// and `solve_dense_sensitivities` methods.
380415
fn dense_allocate_return<'a, Eqn: OdeEquations + 'a, S: OdeSolverMethod<'a, Eqn>>(
381416
s: &S,
382417
t_eval: &[Eqn::T],
383-
) -> Result<<Eqn::V as DefaultDenseMatrix>::M, DiffsolError>
418+
) -> Result<(<Eqn::V as DefaultDenseMatrix>::M, Eqn::V), DiffsolError>
384419
where
385420
Eqn::V: DefaultDenseMatrix,
386421
{
@@ -399,7 +434,12 @@ where
399434
if t_eval.windows(2).any(|w| w[0] > w[1] || w[0] < t0) {
400435
return Err(ode_solver_error!(InvalidTEval));
401436
}
402-
Ok(ret)
437+
let tmp_nout = if let Some(out) = s.problem().eqn.out() {
438+
Eqn::V::zeros(out.nout(), s.problem().context().clone())
439+
} else {
440+
Eqn::V::zeros(0, s.problem().context().clone())
441+
};
442+
Ok((ret, tmp_nout))
403443
}
404444

405445
#[cfg(test)]

0 commit comments

Comments
 (0)