Skip to content

Commit 2ce1875

Browse files
committed
new states_transformer struct added to the cuqdyn config and updated the rust fragment to be able to use it
1 parent 435d84e commit 2ce1875

File tree

12 files changed

+487
-184
lines changed

12 files changed

+487
-184
lines changed

example-files/nfkb_cuqdyn_config.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,12 @@
2222
p1 * y11 * y7 - p24 * p22 * y14
2323
p28 + p27 * y7 - p29 * y15
2424
</ode_expr>
25+
<states_transformer count="6">
26+
y7
27+
y10 + y13
28+
y9
29+
y1 + y2 + y3
30+
y2
31+
y12
32+
</states_transformer>
2533
</cuqdyn-config>

modules/cuqdyn-c/include/config.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
typedef struct
77
{
8-
sunrealtype rtol;
9-
N_Vector atol;
8+
double rtol;
9+
int atol_len;
10+
double *atol;
1011

1112
} Tolerances;
1213

@@ -23,15 +24,25 @@ typedef struct
2324
OdeExpr create_ode_expr(int y_count, int p_count, char** exprs);
2425
void destroy_ode_expr(OdeExpr ode_expr);
2526

27+
typedef struct
28+
{
29+
int count;
30+
char** exprs;
31+
} StatesTransformer;
32+
33+
StatesTransformer create_states_transformer(int count, char** exprs);
34+
void destroy_states_transformer(StatesTransformer observables);
35+
2636
typedef struct
2737
{
2838
Tolerances tolerances;
2939
OdeExpr ode_expr;
40+
StatesTransformer states_transformer;
3041
} CuqdynConf;
3142

3243
CuqdynConf *init_cuqdyn_conf_from_file(const char *filename);
3344
int parse_cuqdyn_conf(const char* filename, CuqdynConf* config);
34-
CuqdynConf *init_cuqdyn_conf(Tolerances tolerances, OdeExpr ode_expr);
45+
CuqdynConf *init_cuqdyn_conf(Tolerances tolerances, OdeExpr ode_expr, StatesTransformer observables);
3546
void destroy_cuqdyn_conf();
3647
CuqdynConf * get_cuqdyn_conf();
3748

modules/cuqdyn-c/include/functions.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
#include "config.h"
77

8-
void mexpreval_init_wrapper(OdeExpr ode_expr);
8+
void mexpreval_init_wrapper(CuqdynConf cuqdyn_conf);
99
/// Function used to solve the ODE using cvodes
1010
int ode_model_fun(sunrealtype t, N_Vector y, N_Vector ydot, void *data);
1111
/// Objetive function for the lotka_volterra problem used by the sacess library

modules/cuqdyn-c/src/config.c

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313

1414
static CuqdynConf *config = NULL;
1515

16-
CuqdynConf *create_cuqdyn_conf(Tolerances tolerances, OdeExpr ode_expr)
16+
CuqdynConf *create_cuqdyn_conf(Tolerances tolerances, OdeExpr ode_expr, StatesTransformer states_transformer)
1717
{
1818
CuqdynConf *cuqdyn_conf = malloc(sizeof(CuqdynConf));
1919
cuqdyn_conf->tolerances = tolerances;
2020
cuqdyn_conf->ode_expr = ode_expr;
21+
cuqdyn_conf->states_transformer = states_transformer;
2122
return cuqdyn_conf;
2223
}
2324

@@ -30,6 +31,7 @@ void destroy_cuqdyn_conf()
3031

3132
destroy_tolerances(config->tolerances);
3233
destroy_ode_expr(config->ode_expr);
34+
destroy_states_transformer(config->states_transformer);
3335
free(config);
3436
config = NULL;
3537
}
@@ -39,8 +41,7 @@ CuqdynConf *init_cuqdyn_conf_from_file(const char *filename)
3941
CuqdynConf *tmp_config = malloc(sizeof(CuqdynConf));
4042
if (tmp_config == NULL)
4143
{
42-
fprintf(stderr, "ERROR: Memory allocation failed in function "
43-
"init_cuqdyn_conf_from_file()\n");
44+
fprintf(stderr, "ERROR: Memory allocation failed in function init_cuqdyn_conf_from_file()\n");
4445
exit(1);
4546
}
4647

@@ -53,7 +54,7 @@ CuqdynConf *init_cuqdyn_conf_from_file(const char *filename)
5354
}
5455

5556
config = tmp_config;
56-
mexpreval_init_wrapper(config->ode_expr);
57+
mexpreval_init_wrapper(*config);
5758
return config;
5859
}
5960

@@ -71,10 +72,14 @@ int parse_cuqdyn_conf(const char *filename, CuqdynConf *config)
7172

7273
sunrealtype rtol = 1e-6;
7374
N_Vector atol = NULL;
75+
7476
int y_count = 0;
7577
char **odes = NULL;
7678
int p_count = 0;
7779

80+
int obs_count = 0;
81+
char **obs_tranformations = NULL;
82+
7883
for (; cur; cur = cur->next)
7984
{
8085
if (cur->type != XML_ELEMENT_NODE)
@@ -185,19 +190,61 @@ int parse_cuqdyn_conf(const char *filename, CuqdynConf *config)
185190
token = strtok(NULL, "\n");
186191
}
187192
}
193+
else if (!xmlStrcmp(cur->name, "states_transformer"))
194+
{
195+
xmlChar *count_attr = xmlGetProp(cur, "count");
196+
if (count_attr == NULL)
197+
{
198+
fprintf(stderr, "Error: <states_transformer> node does not have a 'count' attribute\n");
199+
xmlFreeDoc(doc);
200+
return 1;
201+
}
202+
203+
obs_count = atoi((char *) count_attr);
204+
xmlFree(count_attr);
205+
206+
xmlChar *content = xmlNodeGetContent(cur);
207+
if (content == NULL)
208+
{
209+
fprintf(stderr, "Error: no content in <states_transformer> node\n");
210+
xmlFreeDoc(doc);
211+
return 1;
212+
}
213+
214+
char *str = (char *) content;
215+
216+
char *token = strtok(str, "\n");
217+
218+
obs_tranformations = calloc(y_count, sizeof(char *));
219+
220+
int index = 0;
221+
while (token != NULL)
222+
{
223+
while (*token == ' ')
224+
token++;
225+
226+
if (*token != '\0')
227+
{
228+
obs_tranformations[index] = token;
229+
index++;
230+
}
231+
token = strtok(NULL, "\n");
232+
}
233+
}
188234
}
189235

190236
config->tolerances = create_tolerances(rtol, atol);
191237
config->ode_expr = create_ode_expr(y_count, p_count, odes);
238+
config->states_transformer = create_states_transformer(obs_count, obs_tranformations);
192239
xmlFreeDoc(doc);
193240
return 0;
194241
}
195242

196-
CuqdynConf *init_cuqdyn_conf(Tolerances tolerances, OdeExpr ode_expr)
243+
CuqdynConf *init_cuqdyn_conf(Tolerances tolerances, OdeExpr ode_expr, StatesTransformer states_transformer)
197244
{
198245
destroy_cuqdyn_conf();
199246

200-
config = create_cuqdyn_conf(tolerances, ode_expr);
247+
config = create_cuqdyn_conf(tolerances, ode_expr, states_transformer);
201248
return config;
202249
}
203250

@@ -217,11 +264,13 @@ Tolerances create_tolerances(sunrealtype scalar_rtol, N_Vector atol)
217264
{
218265
Tolerances tolerances;
219266
tolerances.rtol = scalar_rtol;
220-
tolerances.atol = atol;
267+
tolerances.atol_len = NV_LENGTH_S(atol);
268+
tolerances.atol = malloc(tolerances.atol_len * sizeof(double));
269+
memcpy(tolerances.atol, NV_DATA_S(atol), tolerances.atol_len * sizeof(double));
221270
return tolerances;
222271
}
223272

224-
void destroy_tolerances(Tolerances tolerances) { N_VDestroy(tolerances.atol); }
273+
void destroy_tolerances(Tolerances tolerances) { free(tolerances.atol); }
225274

226275
OdeExpr create_ode_expr(int y_count, int p_count, char **exprs)
227276
{
@@ -233,3 +282,16 @@ OdeExpr create_ode_expr(int y_count, int p_count, char **exprs)
233282
}
234283

235284
void destroy_ode_expr(OdeExpr ode_expr) {}
285+
286+
StatesTransformer create_states_transformer(int count, char** exprs)
287+
{
288+
StatesTransformer states_transformer;
289+
states_transformer.count = count;
290+
states_transformer.exprs = exprs;
291+
return states_transformer;
292+
}
293+
294+
void destroy_states_transformer(StatesTransformer states_transformer)
295+
{
296+
297+
}

modules/cuqdyn-c/src/functions.c

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
#include "config.h"
88
#include "cuqdyn.h"
99

10-
extern void mexpreval_init(OdeExpr ode_expr);
10+
extern void mexpreval_init(CuqdynConf cuqdyn_conf);
1111

12-
void mexpreval_init_wrapper(OdeExpr ode_expr)
12+
void mexpreval_init_wrapper(CuqdynConf cuqdyn_conf)
1313
{
14-
mexpreval_init(ode_expr);
14+
mexpreval_init(cuqdyn_conf);
1515
}
1616

1717
extern void eval_f_exprs(sunrealtype t, sunrealtype *y, sunrealtype *ydot, sunrealtype *params);
@@ -60,6 +60,18 @@ void *ode_model_obj_func(double *x, void *data)
6060

6161
sunrealtype J = 0.0;
6262

63+
if (SM_ROWS_D(exptotal->yexp) != rows)
64+
{
65+
fprintf(stderr, "ERROR: The yexp rows don't match the ode result rows");
66+
exit(-1);
67+
}
68+
69+
if (SM_COLUMNS_D(exptotal->yexp) != cols - 1)
70+
{
71+
fprintf(stderr, "ERROR: The yexp cols don't match the ode result cols");
72+
exit(-1);
73+
}
74+
6375
for (long i = 0; i < rows; ++i)
6476
{
6577
// Note that the first col of the result matrix is t

modules/cuqdyn-c/src/ode_solver.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ SUNMatrix solve_ode(N_Vector parameters, N_Vector initial_values, sunrealtype t0
2222
retval = CVodeInit(cvode_mem, ode_model_fun, t0, initial_values);
2323
if (check_retval(&retval, "CVodeInit", 1)) { return NULL; }
2424

25-
N_Vector cloned_abs_tol = New_Serial(NV_LENGTH_S(tolerances.atol));
26-
memcpy(NV_DATA_S(cloned_abs_tol), NV_DATA_S(tolerances.atol), NV_LENGTH_S(tolerances.atol) * sizeof(sunrealtype));
25+
N_Vector cloned_abs_tol = New_Serial(tolerances.atol_len);
26+
memcpy(NV_DATA_S(cloned_abs_tol), tolerances.atol, tolerances.atol_len * sizeof(sunrealtype));
2727

2828
// We clone the tolerances because the CVodeFree function frees the memory allocated for the abs_tol it receives
2929
retval = CVodeSVtolerances(cvode_mem, tolerances.rtol, cloned_abs_tol);

modules/mexpreval/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ edition = "2018"
55

66
[dependencies]
77
meval = { git = "https://github.com/ZocoLini/mexpreval-rs.git", tag = "v0.2.2" }
8+
getset = { version = "0.1.5" }
89

910
[dev-dependencies]
1011
criterion = { version = "0.5.1" }

modules/mexpreval/benches/eval_f_exprs_bench.rs

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,15 @@
11
use criterion::{criterion_group, criterion_main, Criterion};
2-
use mexpreval::{eval_f_exprs, mexpreval_init, OdeExpr};
3-
use std::ffi::CString;
4-
use std::os::raw::c_char;
2+
use mexpreval::config::CuqdynConfig;
3+
use mexpreval::{eval_f_exprs, mexpreval_init};
54

65
fn lotka_volterra_bench_eval(c: &mut Criterion) {
76
let num_exprs = 2;
8-
let num_params = 4;
97

108
let mut y = vec![1.0, 2.0];
119
let mut ydot = vec![0.0; num_exprs];
1210
let mut params = vec![0.1, 0.2, 0.3, 0.4];
1311

14-
let expr_strings = [
15-
CString::new("y1 * (p1 - p2 * y2)").unwrap(),
16-
CString::new("-y2 * (p3 - p4 * y1)").unwrap(),
17-
];
18-
let expr_ptrs: Vec<*const c_char> = expr_strings.iter().map(|s| s.as_ptr()).collect();
19-
20-
unsafe { mexpreval_init(OdeExpr::new(num_exprs as i32, num_params, expr_ptrs.as_ptr())); }
12+
unsafe { mexpreval_init(CuqdynConfig::lotka_volterra_expr().into()); }
2113

2214
c.bench_function("lotka_volterra", |b| {
2315
b.iter(|| unsafe {
@@ -34,18 +26,12 @@ fn lotka_volterra_bench_eval(c: &mut Criterion) {
3426
#[allow(dead_code)]
3527
fn lotka_volterra_predefined_bench_eval(c: &mut Criterion) {
3628
let num_exprs = 2;
37-
let num_params = 4;
3829

3930
let mut y = vec![1.0, 2.0];
4031
let mut ydot = vec![0.0; num_exprs];
4132
let mut params = vec![0.1, 0.2, 0.3, 0.4];
4233

43-
let expr_strings = [
44-
CString::new("lotka-volterra").unwrap(),
45-
];
46-
let expr_ptrs: Vec<*const c_char> = expr_strings.iter().map(|s| s.as_ptr()).collect();
47-
48-
unsafe { mexpreval_init(OdeExpr::new(num_exprs as i32, num_params, expr_ptrs.as_ptr())); }
34+
unsafe { mexpreval_init(CuqdynConfig::lotka_volterra().into()); }
4935

5036
c.bench_function("lotka_volterra_predefined", |b| {
5137
b.iter(|| unsafe {
@@ -61,18 +47,12 @@ fn lotka_volterra_predefined_bench_eval(c: &mut Criterion) {
6147

6248
fn logistic_model_bench_eval(c: &mut Criterion) {
6349
let num_exprs = 1;
64-
let num_params = 2;
6550

6651
let mut y = vec![0.0];
6752
let mut ydot = vec![0.0; num_exprs];
6853
let mut params = vec![0.1, 100.0];
6954

70-
let expr_strings = [
71-
CString::new("p1 * y1 * (1 - y1 / p2)").unwrap()
72-
];
73-
let expr_ptrs: Vec<*const c_char> = expr_strings.iter().map(|s| s.as_ptr()).collect();
74-
75-
unsafe { mexpreval_init(OdeExpr::new(num_exprs as i32, num_params, expr_ptrs.as_ptr())); }
55+
unsafe { mexpreval_init(CuqdynConfig::logistic_growth_expr().into()); }
7656

7757
c.bench_function("logistic_model", |b| {
7858
b.iter(|| unsafe {
@@ -88,22 +68,12 @@ fn logistic_model_bench_eval(c: &mut Criterion) {
8868

8969
fn alpha_pinene_bench_eval(c: &mut Criterion) {
9070
let num_exprs = 5;
91-
let num_params = 5;
9271

9372
let mut y = vec![1.0, 1.0, 1.0, 1.0, 1.0];
9473
let mut ydot = vec![0.0; num_exprs];
9574
let mut params = vec![0.1, 0.2, 0.2, 0.2, 0.2];
9675

97-
let expr_strings = [
98-
CString::new("-(p1 + p2) * y1").unwrap(),
99-
CString::new("p1 * y1").unwrap(),
100-
CString::new("p2 * y1 - (p3 + p4) * y3 + p5 * y5").unwrap(),
101-
CString::new("p3 * y3").unwrap(),
102-
CString::new("p4 * y3 - p5 * y5").unwrap(),
103-
];
104-
let expr_ptrs: Vec<*const c_char> = expr_strings.iter().map(|s| s.as_ptr()).collect();
105-
106-
unsafe { mexpreval_init(OdeExpr::new(num_exprs as i32, num_params, expr_ptrs.as_ptr())); }
76+
unsafe { mexpreval_init(CuqdynConfig::alpha_pinene_expr().into()); }
10777

10878
c.bench_function("alpha_pinene", |b| {
10979
b.iter(|| unsafe {

0 commit comments

Comments
 (0)