Skip to content

Commit 717963c

Browse files
refactor: change asserts to errors in tableau checks (#180)
* refactor: change asserts to errors in tableau checks * collapse if statement
1 parent 4352c83 commit 717963c

File tree

2 files changed

+63
-46
lines changed

2 files changed

+63
-46
lines changed

diffsol/src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ pub enum OdeSolverError {
100100
ProblemNotSet,
101101
#[error("Jacobian not available")]
102102
JacobianNotAvailable,
103+
#[error("Invalid Tableau: {0}")]
104+
InvalidTableau(String),
103105
#[error("Error: {0}")]
104106
Other(String),
105107
}

diffsol/src/ode_solver/runge_kutta.rs

Lines changed: 61 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -225,36 +225,45 @@ where
225225
let s = tableau.s();
226226
for i in 0..s {
227227
for j in i..s {
228-
assert_eq!(
229-
tableau.a().get_index(i, j),
230-
Eqn::T::zero(),
231-
"Invalid tableau, expected a(i, j) = 0 for i >= j"
232-
);
228+
if tableau.a().get_index(i, j) != Eqn::T::zero() {
229+
return Err(ode_solver_error!(
230+
InvalidTableau,
231+
format!(
232+
"Invalid tableau, expected a(i, j) = 0 for i >= j, but found a({}, {}) = {}",
233+
i,
234+
j,
235+
tableau.a().get_index(i, j)
236+
)
237+
));
238+
}
233239
}
234240
}
235241

236242
// check last row of a is the same as b
237243
for i in 0..s {
238-
assert_eq!(
239-
tableau.a().get_index(s - 1, i),
240-
tableau.b().get_index(i),
241-
"Invalid tableau, expected a(s-1, i) = b(i)"
242-
);
244+
if tableau.a().get_index(s - 1, i) != tableau.b().get_index(i) {
245+
return Err(ode_solver_error!(
246+
InvalidTableau,
247+
"Invalid tableau, expected a(s-1, i) = b(i)"
248+
));
249+
}
243250
}
244251

245252
// check that last c is 1
246-
assert_eq!(
247-
tableau.c().get_index(s - 1),
248-
Eqn::T::one(),
249-
"Invalid tableau, expected c(s-1) = 1"
250-
);
253+
if tableau.c().get_index(s - 1) != Eqn::T::one() {
254+
return Err(ode_solver_error!(
255+
InvalidTableau,
256+
"Invalid tableau, expected c(s-1) = 1"
257+
));
258+
}
251259

252260
// check that first c is 0
253-
assert_eq!(
254-
tableau.c().get_index(0),
255-
Eqn::T::zero(),
256-
"Invalid tableau, expected c(0) = 0"
257-
);
261+
if tableau.c().get_index(0) != Eqn::T::zero() {
262+
return Err(ode_solver_error!(
263+
InvalidTableau,
264+
"Invalid tableau, expected c(0) = 0"
265+
));
266+
}
258267
Ok(())
259268
}
260269

@@ -267,54 +276,60 @@ where
267276
let s = tableau.s();
268277
for i in 0..s {
269278
for j in (i + 1)..s {
270-
assert_eq!(
271-
tableau.a().get_index(i, j),
272-
Eqn::T::zero(),
273-
"Invalid tableau, expected a(i, j) = 0 for i > j"
274-
);
279+
if tableau.a().get_index(i, j) != Eqn::T::zero() {
280+
return Err(ode_solver_error!(
281+
InvalidTableau,
282+
"Invalid tableau, expected a(i, j) = 0 for i > j"
283+
));
284+
}
275285
}
276286
}
277287
let gamma = tableau.a().get_index(1, 1);
278288
//check that for i = 1..s-1, a(i, i) = gamma
279289
for i in 1..tableau.s() {
280-
assert_eq!(
281-
tableau.a().get_index(i, i),
282-
gamma,
283-
"Invalid tableau, expected a(i, i) = gamma = {gamma} for i = 1..s-1",
284-
);
290+
if tableau.a().get_index(i, i) != gamma {
291+
return Err(ode_solver_error!(
292+
InvalidTableau,
293+
format!("Invalid tableau, expected a(i, i) = gamma = {gamma} for i = 1..s-1")
294+
));
295+
}
285296
}
286297
// if a(0, 0) = gamma, then we're a SDIRK method
287298
// if a(0, 0) = 0, then we're a ESDIRK method
288299
// otherwise, error
289300
let zero = Eqn::T::zero();
290301
if tableau.a().get_index(0, 0) != zero && tableau.a().get_index(0, 0) != gamma {
291-
panic!("Invalid tableau, expected a(0, 0) = 0 or a(0, 0) = gamma");
302+
return Err(ode_solver_error!(
303+
InvalidTableau,
304+
"Invalid tableau, expected a(0, 0) = 0 or a(0, 0) = gamma"
305+
));
292306
}
293307
let is_sdirk = tableau.a().get_index(0, 0) == gamma;
294308

295309
// check last row of a is the same as b
296310
for i in 0..s {
297-
assert_eq!(
298-
tableau.a().get_index(s - 1, i),
299-
tableau.b().get_index(i),
300-
"Invalid tableau, expected a(s-1, i) = b(i)"
301-
);
311+
if tableau.a().get_index(s - 1, i) != tableau.b().get_index(i) {
312+
return Err(ode_solver_error!(
313+
InvalidTableau,
314+
"Invalid tableau, expected a(s-1, i) = b(i)"
315+
));
316+
}
302317
}
303318

304319
// check that last c is 1
305-
assert_eq!(
306-
tableau.c().get_index(s - 1),
307-
Eqn::T::one(),
308-
"Invalid tableau, expected c(s-1) = 1"
309-
);
320+
if tableau.c().get_index(s - 1) != Eqn::T::one() {
321+
return Err(ode_solver_error!(
322+
InvalidTableau,
323+
"Invalid tableau, expected c(s-1) = 1"
324+
));
325+
}
310326

311327
// check that the first c is 0 for esdirk methods
312-
if !is_sdirk {
313-
assert_eq!(
314-
tableau.c().get_index(0),
315-
Eqn::T::zero(),
328+
if !is_sdirk && tableau.c().get_index(0) != Eqn::T::zero() {
329+
return Err(ode_solver_error!(
330+
InvalidTableau,
316331
"Invalid tableau, expected c(0) = 0 for esdirk methods"
317-
);
332+
));
318333
}
319334
Ok(())
320335
}

0 commit comments

Comments
 (0)