Skip to content

Commit fbc511f

Browse files
committed
found and solved the bug in the tests
1 parent 4a16293 commit fbc511f

File tree

4 files changed

+79
-22
lines changed

4 files changed

+79
-22
lines changed

deps.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ set(RUST_LIB_DIR "${PROJECT_SOURCE_DIR}/modules/mexpreval")
8282
set(RUST_TARGET_DIR "${RUST_LIB_DIR}/target/release")
8383

8484
add_custom_target(rust_lib ALL
85+
env RUSTFLAGS=-C target-cpu=native
8586
COMMAND cargo build --release
8687
WORKING_DIRECTORY ${RUST_LIB_DIR}
8788
)

modules/mexpreval/benches/eval_f_exprs_bench.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ fn lotka_volterra_bench_eval(c: &mut Criterion) {
1212
let mut params = vec![0.1, 0.2, 0.3, 0.4];
1313

1414
let expr_strings = [
15-
CString::new("p1 * y1 - p2 * y2").unwrap(),
16-
CString::new("p3 * y2 - p4 * y1").unwrap(),
15+
CString::new("y1 * (p1 - p2 * y2)").unwrap(),
16+
CString::new("-y2 * (p3 - p4 * y1)").unwrap(),
1717
];
1818
let expr_ptrs: Vec<*const c_char> = expr_strings.iter().map(|s| s.as_ptr()).collect();
1919

@@ -31,6 +31,7 @@ fn lotka_volterra_bench_eval(c: &mut Criterion) {
3131
});
3232
}
3333

34+
#[allow(dead_code)]
3435
fn lotka_volterra_predefined_bench_eval(c: &mut Criterion) {
3536
let num_exprs = 2;
3637
let num_params = 4;
@@ -62,9 +63,9 @@ fn logistic_model_bench_eval(c: &mut Criterion) {
6263
let num_exprs = 1;
6364
let num_params = 2;
6465

65-
let mut y = vec![1.0];
66+
let mut y = vec![0.0];
6667
let mut ydot = vec![0.0; num_exprs];
67-
let mut params = vec![0.1, 0.2];
68+
let mut params = vec![0.1, 100.0];
6869

6970
let expr_strings = [
7071
CString::new("p1 * y1 * (1 - y1 / p2)").unwrap()
@@ -116,5 +117,5 @@ fn alpha_pinene_bench_eval(c: &mut Criterion) {
116117
});
117118
}
118119

119-
criterion_group!(benches, lotka_volterra_bench_eval, logistic_model_bench_eval, alpha_pinene_bench_eval, lotka_volterra_predefined_bench_eval);
120+
criterion_group!(benches, lotka_volterra_bench_eval, logistic_model_bench_eval, alpha_pinene_bench_eval);
120121
criterion_main!(benches);

modules/mexpreval/src/lib.rs

Lines changed: 66 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#![allow(static_mut_refs)]
2+
#![allow(invalid_reference_casting)]
23
mod models;
34

4-
use meval::{Context, Expr};
5-
use std::cell::OnceCell;
5+
use meval::{Context, ContextProvider, Expr};
66
use std::str::FromStr;
77
use std::sync::LazyLock;
88
use std::{env, ffi::CStr, os::raw::c_char, slice};
@@ -36,10 +36,10 @@ impl OdeExpr {
3636
}
3737

3838
static mut EXPRS: Vec<Expr> = Vec::new();
39-
static mut CONTEXT: OnceCell<Context<'static>> = OnceCell::new();
39+
static mut CONTEXT: Option<Context<'static>> = None;
4040
static mut Y: Vec<String> = Vec::new();
4141
static mut P: Vec<String> = Vec::new();
42-
static mut ODE_EXPR: OnceCell<OdeExpr> = OnceCell::new();
42+
static mut ODE_EXPR: Option<OdeExpr> = None;
4343
static DEF_YDOT: LazyLock<f64> = LazyLock::new(|| {
4444
env::var("CUQDYN_DEF_YDOT")
4545
.unwrap_or_else(|_| "0.0".to_string())
@@ -50,11 +50,13 @@ static DEF_YDOT: LazyLock<f64> = LazyLock::new(|| {
5050
#[allow(clippy::missing_safety_doc)]
5151
#[no_mangle]
5252
pub unsafe extern "C" fn mexpreval_init(ode_expr: OdeExpr) {
53+
Y.clear();
5354
Y.extend(
5455
(0..ode_expr.y_count)
5556
.map(|i| format!("y{}", i + 1))
5657
.collect::<Vec<String>>(),
5758
);
59+
P.clear();
5860
P.extend(
5961
(0..ode_expr.p_count)
6062
.map(|i| format!("p{}", i + 1))
@@ -63,6 +65,7 @@ pub unsafe extern "C" fn mexpreval_init(ode_expr: OdeExpr) {
6365

6466
let exprs: &[*const c_char] = slice::from_raw_parts(ode_expr.exprs, ode_expr.y_count as usize);
6567

68+
EXPRS.clear();
6669
for ptr in exprs.iter() {
6770
let c_str = CStr::from_ptr(*ptr);
6871
let s = c_str.to_str().unwrap();
@@ -71,15 +74,15 @@ pub unsafe extern "C" fn mexpreval_init(ode_expr: OdeExpr) {
7174

7275
EXPRS.push(expr);
7376
}
74-
75-
let _ = CONTEXT.set(Context::new());
76-
let _ = ODE_EXPR.set(ode_expr);
77+
78+
CONTEXT = Some(Context::new());
79+
ODE_EXPR = Some(ode_expr);
7780
}
7881

7982
#[allow(clippy::missing_safety_doc)]
8083
#[no_mangle]
8184
pub unsafe extern "C" fn eval_f_exprs(_t: f64, y: *mut f64, ydot: *mut f64, params: *mut f64) {
82-
let ctx = CONTEXT.get_mut().unwrap_unchecked();
85+
let ctx = CONTEXT.as_mut().unwrap();
8386

8487
let y: &[f64] = slice::from_raw_parts(y, Y.len());
8588
let ydot: &mut [f64] = slice::from_raw_parts_mut(ydot, Y.len());
@@ -94,9 +97,8 @@ pub unsafe extern "C" fn eval_f_exprs(_t: f64, y: *mut f64, ydot: *mut f64, para
9497
}
9598

9699
for (i, expr) in EXPRS.iter().enumerate() {
97-
ydot[i] = expr
98-
.eval_with_context(CONTEXT.get().unwrap_unchecked())
99-
.unwrap_unchecked();
100+
let ctx = CONTEXT.as_ref().unwrap();
101+
ydot[i] = expr.eval_with_context(ctx).unwrap();
100102

101103
if !ydot[i].is_finite() {
102104
eprintln!(
@@ -117,6 +119,59 @@ mod test {
117119
use std::ffi::CString;
118120
use std::os::raw::c_char;
119121

122+
#[test]
123+
fn lotka_volterra_test() {
124+
let num_exprs = 2;
125+
let num_params = 4;
126+
127+
let mut y = vec![1.0, 1.0];
128+
let mut ydot = vec![0.0; num_exprs];
129+
let mut params = vec![1.0, 2.0, 3.0, 4.0];
130+
131+
let expr_strings = [
132+
CString::new("y1 * (p1 - p2 * y2)").unwrap(),
133+
CString::new("-y2 * (p3 - p4 * y1)").unwrap(),
134+
];
135+
let expr_ptrs: Vec<*const c_char> = expr_strings.iter().map(|s| s.as_ptr()).collect();
136+
137+
unsafe {
138+
mexpreval_init(OdeExpr::new(
139+
num_exprs as i32,
140+
num_params,
141+
expr_ptrs.as_ptr(),
142+
));
143+
}
144+
145+
unsafe { eval_f_exprs(0.0, y.as_mut_ptr(), ydot.as_mut_ptr(), params.as_mut_ptr()) }
146+
147+
assert_eq!(ydot[0], -1.0);
148+
assert_eq!(ydot[1], 1.0);
149+
}
150+
151+
#[test]
152+
fn logistic_model_test() {
153+
let num_exprs = 1;
154+
let num_params = 2;
155+
156+
let mut y = vec![1.0];
157+
let mut ydot = vec![0.0; num_exprs];
158+
let mut params = vec![0.1, 100.0];
159+
160+
let expr_strings = [CString::new("p1 * y1 * (1 - y1 / p2)").unwrap()];
161+
let expr_ptrs: Vec<*const c_char> = expr_strings.iter().map(|s| s.as_ptr()).collect();
162+
163+
unsafe {
164+
mexpreval_init(OdeExpr::new(
165+
num_exprs as i32,
166+
num_params,
167+
expr_ptrs.as_ptr(),
168+
));
169+
}
170+
171+
unsafe { eval_f_exprs(0.0, y.as_mut_ptr(), ydot.as_mut_ptr(), params.as_mut_ptr()) }
172+
assert_eq!(ydot[0], 0.99)
173+
}
174+
120175
#[test]
121176
fn bench() {
122177
let num_exprs = 1;

tests/test_ess_solver.c

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,9 @@ int main(int argc, char **argv)
2323
return 0;
2424
#endif
2525

26-
logistic_model_ess(LOGISTIC_MODEL_CONF_FILE_NL2SOL_DN2FB);
27-
printf("\tTest 5 passed Logistic Model NL2SOL_DN2FB\n");
28-
2926
lotka_volterra_ess(LOTKA_VOLTERRA_CONF_FILE_NL2SOL_DN2GB);
3027
printf("\tTest 1 passed NL2SOL_DN2GB\n");
3128

32-
alpha_pinene_ess(ALPHA_PINENE_CONF_FILE_NL2SOL_DN2FB);
33-
printf("\tTest 6 passed Alpha-Pinene NL2SOL_DN2GB\n");
34-
3529
lotka_volterra_ess(LOTKA_VOLTERRA_CONF_FILE_NL2SOL_DN2FB);
3630
printf("\tTest 2 passed NL2SOL_DN2FB\n");
3731

@@ -41,6 +35,12 @@ int main(int argc, char **argv)
4135
lotka_volterra_ess(LOTKA_VOLTERRA_CONF_FILE_MISQP);
4236
printf("\tTest 4 passed MISQP\n");
4337

38+
logistic_model_ess(LOGISTIC_MODEL_CONF_FILE_NL2SOL_DN2FB);
39+
printf("\tTest 5 passed Logistic Model NL2SOL_DN2FB\n");
40+
41+
alpha_pinene_ess(ALPHA_PINENE_CONF_FILE_NL2SOL_DN2FB);
42+
printf("\tTest 6 passed Alpha-Pinene NL2SOL_DN2GB\n");
43+
4444
return 0;
4545
}
4646

0 commit comments

Comments
 (0)