Skip to content

Commit e268a96

Browse files
feat: create heat1d diffsl problem for tsit45 (#185)
1 parent 68b066f commit e268a96

File tree

4 files changed

+181
-2
lines changed

4 files changed

+181
-2
lines changed

diffsol/benches/ode_solvers.rs

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ fn criterion_benchmark(c: &mut Criterion) {
2424
};
2525
}
2626

27+
macro_rules! bench_explicit {
28+
($name:ident, $solver:ident, $model:ident, $model_problem:ident, $matrix:ty) => {
29+
c.bench_function(stringify!($name), |b| {
30+
b.iter(|| {
31+
let (problem, soln) = $model_problem::<$matrix>(false);
32+
benchmarks::$solver::<_>(&problem, soln.solution_points.last().unwrap().t);
33+
})
34+
});
35+
};
36+
}
37+
2738
bench!(
2839
nalgebra_bdf_exponential_decay,
2940
bdf,
@@ -48,6 +59,13 @@ fn criterion_benchmark(c: &mut Criterion) {
4859
exponential_decay_problem,
4960
NalgebraMat<f64>
5061
);
62+
bench_explicit!(
63+
nalgebra_tsit45_exponential_decay,
64+
tsit45,
65+
exponential_decay,
66+
exponential_decay_problem,
67+
NalgebraMat<f64>
68+
);
5169
bench!(
5270
nalgebra_bdf_robertson,
5371
bdf,
@@ -294,30 +312,58 @@ fn criterion_benchmark(c: &mut Criterion) {
294312
20,
295313
30
296314
);
315+
297316
macro_rules! bench_diffsl_heat2d {
298-
($name:ident, $solver:ident, $linear_solver:ident, $matrix:ty, $($N:expr),+) => {
317+
($name:ident, $solver:ident, $linear_solver:ident, $model_problem:ident, $matrix:ty, $($N:expr),+) => {
299318
$(#[cfg(feature = "diffsl-llvm")]
300319
c.bench_function(concat!(stringify!($name), "_", $N), |b| {
301320
use diffsol::ode_equations::test_models::heat2d::*;
302321
use diffsol::LlvmModule;
303-
let (problem, soln) = heat2d_diffsl_problem::<$matrix, LlvmModule, $N>();
322+
let (problem, soln) = $model_problem::<$matrix, LlvmModule, $N>();
304323
b.iter(|| {
305324
benchmarks::$solver::<_, $linear_solver<_>>(&problem, soln.solution_points.last().unwrap().t)
306325
})
307326
});)+
308327
};
309328
}
329+
330+
macro_rules! bench_diffsl_heat1d {
331+
($name:ident, $solver:ident, $model_problem:ident, $matrix:ty, $($N:expr),+) => {
332+
$(#[cfg(feature = "diffsl-llvm")]
333+
c.bench_function(concat!(stringify!($name), "_", $N), |b| {
334+
use diffsol::ode_equations::test_models::heat1d::*;
335+
use diffsol::LlvmModule;
336+
let (problem, soln) = $model_problem::<$matrix, LlvmModule, $N>();
337+
b.iter(|| {
338+
benchmarks::$solver::<_>(&problem, soln.solution_points.last().unwrap().t)
339+
})
340+
});)+
341+
};
342+
}
343+
310344
bench_diffsl_heat2d!(
311345
faer_sparse_bdf_diffsl_heat2d,
312346
bdf,
313347
FaerSparseLU,
348+
heat2d_diffsl_problem,
314349
FaerSparseMat<f64>,
315350
5,
316351
10,
317352
20,
318353
30
319354
);
320355

356+
bench_diffsl_heat1d!(
357+
faer_tsit45_diffsl_heat1d,
358+
tsit45,
359+
heat1d_diffsl_problem,
360+
FaerMat<f64>,
361+
10,
362+
20,
363+
40,
364+
80
365+
);
366+
321367
macro_rules! bench_sundials {
322368
($name:ident, $solver:ident) => {
323369
#[cfg(feature = "sundials")]
@@ -434,4 +480,16 @@ mod benchmarks {
434480
let mut s = problem.tr_bdf2::<LS>().unwrap();
435481
let _y = s.solve(t);
436482
}
483+
484+
pub fn tsit45<Eqn>(problem: &OdeSolverProblem<Eqn>, t: Eqn::T)
485+
where
486+
Eqn: OdeEquationsImplicit,
487+
Eqn::M: Matrix + DefaultSolver,
488+
Eqn::V: DefaultDenseMatrix,
489+
for<'a> &'a Eqn::V: VectorRef<Eqn::V>,
490+
for<'a> &'a Eqn::M: MatrixRef<Eqn::M>,
491+
{
492+
let mut s = problem.tsit45().unwrap();
493+
let _y = s.solve(t);
494+
}
437495
}
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
#[cfg(feature = "diffsl")]
2+
use crate::{ode_solver::problem::OdeSolverSolution, Matrix, MatrixHost, OdeSolverProblem, Vector};
3+
4+
#[cfg(feature = "diffsl")]
5+
#[allow(clippy::type_complexity)]
6+
pub fn heat1d_diffsl_problem<
7+
M: MatrixHost<T = f64>,
8+
CG: crate::CodegenModuleJit + crate::CodegenModuleCompile,
9+
const MGRID: usize,
10+
>() -> (
11+
OdeSolverProblem<impl crate::OdeEquationsImplicit<M = M, V = M::V, T = M::T, C = M::C>>,
12+
OdeSolverSolution<M::V>,
13+
) {
14+
use crate::OdeBuilder;
15+
16+
let mgridp1 = MGRID + 1;
17+
let h = 1.0 / (MGRID + 2) as f64;
18+
let y0 = (0..mgridp1)
19+
.map(|i| {
20+
let x = (i + 1) as f64 * h;
21+
if x < 0.5 {
22+
2.0 * x
23+
} else {
24+
2.0 * (1.0 - x)
25+
}
26+
})
27+
.collect::<Vec<_>>();
28+
let y0_str = y0
29+
.iter()
30+
.enumerate()
31+
.map(|(i, v)| format!("({i}): {v}"))
32+
.collect::<Vec<_>>()
33+
.join(", ");
34+
let code = format!(
35+
"
36+
D {{ 1.0 }}
37+
h {{ {h} }}
38+
A_ij {{
39+
(0..{MGRID}, 1..{mgridp1}): 1.0,
40+
(0..{mgridp1}, 0..{mgridp1}): -2.0,
41+
(1..{mgridp1}, 0..{MGRID}): 1.0,
42+
}}
43+
u_i {{
44+
{y0_str}
45+
}}
46+
heat_i {{ A_ij * u_j }}
47+
F_i {{
48+
D * heat_i / (h * h)
49+
}}
50+
out_i {{ u_i }}
51+
"
52+
);
53+
let problem = OdeBuilder::<M>::new()
54+
.rtol(1e-6)
55+
.atol([1e-6])
56+
.build_from_diffsl::<CG>(code.as_str())
57+
.unwrap();
58+
let soln = soln::<M>(problem.context().clone(), MGRID, h);
59+
(problem, soln)
60+
}
61+
62+
#[cfg(feature = "diffsl")]
63+
fn soln<M: Matrix<T = f64>>(ctx: M::C, mgrid: usize, h: f64) -> OdeSolverSolution<M::V> {
64+
// we'll put rather loose tolerances here, since the initial conditions have a discontinuity
65+
let mut soln = OdeSolverSolution {
66+
solution_points: Vec::new(),
67+
sens_solution_points: None,
68+
rtol: 1e-4,
69+
atol: M::V::from_element(mgrid + 1, 1e-4, ctx.clone()),
70+
negative_time: false,
71+
};
72+
let times = (0..5).map(|i| i as f64 * 0.01 + 0.5).collect::<Vec<_>>();
73+
let data: Vec<_> = times
74+
.iter()
75+
.map(|&t| {
76+
const PI: f64 = std::f64::consts::PI;
77+
let mut ret = vec![0.0; mgrid + 1];
78+
for (i, v) in ret.iter_mut().enumerate() {
79+
let x = (i + 1) as f64 * h;
80+
*v = 0.0;
81+
for n in 1..100 {
82+
let two_n_minus_1: f64 = f64::from(2 * n - 1);
83+
*v += (two_n_minus_1 * PI * x).sin()
84+
* (-two_n_minus_1.powi(2) * PI.powi(2) * t).exp()
85+
/ two_n_minus_1.powi(2);
86+
}
87+
*v *= 8.0 / PI.powi(2);
88+
}
89+
M::V::from_vec(ret, ctx.clone())
90+
})
91+
.collect();
92+
93+
for (values, time) in data.into_iter().zip(times.into_iter()) {
94+
soln.push(values, time);
95+
}
96+
soln
97+
}

diffsol/src/ode_equations/test_models/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pub mod exponential_decay;
33
pub mod exponential_decay_with_algebraic;
44
pub mod foodweb;
55
pub mod gaussian_decay;
6+
pub mod heat1d;
67
pub mod heat2d;
78
pub mod robertson;
89
pub mod robertson_ode;

diffsol/src/ode_solver/explicit_rk.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,29 @@ mod test {
294294
"###);
295295
}
296296

297+
#[cfg(feature = "diffsl-llvm")]
298+
#[test]
299+
fn test_tsit45_nalgebra_heat1d_diffsl() {
300+
use crate::ode_equations::test_models::heat1d::heat1d_diffsl_problem;
301+
302+
let (problem, soln) = heat1d_diffsl_problem::<M, diffsl::LlvmModule, 10>();
303+
let mut s = problem.tsit45().unwrap();
304+
test_ode_solver(&mut s, soln, None, false, false);
305+
insta::assert_yaml_snapshot!(s.get_statistics(), @r###"
306+
number_of_linear_solver_setups: 0
307+
number_of_steps: 93
308+
number_of_error_test_failures: 9
309+
number_of_nonlinear_solver_iterations: 0
310+
number_of_nonlinear_solver_fails: 0
311+
"###);
312+
insta::assert_yaml_snapshot!(problem.eqn.rhs().statistics(), @r###"
313+
number_of_calls: 0
314+
number_of_jac_muls: 0
315+
number_of_matrix_evals: 0
316+
number_of_jac_adj_muls: 0
317+
"###);
318+
}
319+
297320
#[cfg(feature = "cuda")]
298321
#[test]
299322
fn test_tsit45_cuda_exponential_decay() {

0 commit comments

Comments
 (0)