Skip to content

Commit 8175635

Browse files
committed
removed the time from the states matrices
1 parent 8e14148 commit 8175635

File tree

8 files changed

+51
-59
lines changed

8 files changed

+51
-59
lines changed

modules/cuqdyn-c/include/cuqdyn.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,13 @@
88
#define New_Serial(n) N_VNew_Serial(n, get_sundials_ctx())
99

1010
typedef SUNMatrix States;
11-
typedef SUNMatrix ObservablesStates;
1211
/*
1312
* TransposedStates where:
1413
* - Each col corresponds to a time point
15-
* - Row 0: Time values (t)
16-
* - Rows 1-n: Solution components (y1, y2, ..., yn)
14+
* - Rows 0-n: Solution components (y1, y2, ..., yn)
1715
*/
1816
typedef SUNMatrix TransposedStates;
17+
typedef SUNMatrix ObservablesStates;
1918
typedef SUNMatrix ObservablesTransposedStates;
2019
typedef SUNMatrix ObservedData;
2120
typedef SUNMatrix TransposedObservedData;

modules/cuqdyn-c/src/cuqdyn.c

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -164,21 +164,19 @@ CuqdynResult *cuqdyn_algo(const char *data_file, const char *sacess_conf_file, c
164164
execute_ess_solver(sacess_conf_file, output_file, texp, yexp, tmp_initial_condition, initial_params);
165165

166166
// Saving the ode solution data obtained with the predicted params
167-
TransposedStates ode_solution = solve_ode(predicted_params, initial_condition, t0, times);
168-
ObservablesTransposedStates obs_states = transform_states(ode_solution);
169-
ObservablesTransposedStates predicted_data = copy_matrix_remove_rows(obs_states, create_array((long[]) {1L}, 1)); // Removing the time row
170-
SUNMatDestroy(obs_states);
167+
TransposedStates ode_solution_states = solve_ode(predicted_params, initial_condition, t0, times);
168+
ObservablesTransposedStates predicted_obs_states = transform_states(ode_solution_states);
171169

172170
for (int j = 0; j < n; ++j)
173171
{
174172
sunrealtype observed = SM_ELEMENT_D(observed_data, i, j);
175-
sunrealtype predicted = SM_ELEMENT_D(predicted_data, j, i);
173+
sunrealtype predicted = SM_ELEMENT_D(predicted_obs_states, j, i);
176174

177175
NV_Ith_S(residuals, j) = fabs(observed - predicted);
178176
}
179177

180178
#ifdef MPI
181-
long predicted_data_len = SM_COLUMNS_D(predicted_data) * SM_ROWS_D(predicted_data);
179+
long predicted_data_len = SM_COLUMNS_D(predicted_obs_states) * SM_ROWS_D(predicted_obs_states);
182180
if (rank != 0)
183181
{
184182
// Sending
@@ -196,7 +194,7 @@ CuqdynResult *cuqdyn_algo(const char *data_file, const char *sacess_conf_file, c
196194
{
197195
#endif
198196
set_matrix_row(resid_loo, residuals, i, 0, NV_LENGTH_S(residuals));
199-
matrix_array_set_index(media_matrix, i - 1, predicted_data);
197+
matrix_array_set_index(media_matrix, i - 1, predicted_obs_states);
200198
set_matrix_row(predicted_params_matrix, predicted_params, i, 0, NV_LENGTH_S(predicted_params));
201199
#ifdef MPI
202200
// Receiving
@@ -225,7 +223,7 @@ CuqdynResult *cuqdyn_algo(const char *data_file, const char *sacess_conf_file, c
225223
}
226224

227225
N_VDestroy(predicted_params);
228-
SUNMatDestroy(predicted_data);
226+
SUNMatDestroy(predicted_obs_states);
229227
}
230228

231229
N_VDestroy(residuals);

modules/cuqdyn-c/src/ess_solver.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#include <method_module/structure_paralleltestbed.h>
55
#include "method_module/solversinterface.h"
6-
#include "output/output.h"
76

87
#include "ess_solver.h"
98

modules/cuqdyn-c/src/functions.c

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,17 @@ void *obj_func(double *x, void *data)
5252
}
5353

5454
// We compare rows with cols because the result matrix is transposed
55-
if (SM_COLUMNS_D(exptotal->yexp) != rows - 1)
55+
if (SM_COLUMNS_D(exptotal->yexp) != rows)
5656
{
5757
fprintf(stderr, "ERROR: The yexp cols don't match the ode result cols: %ld vs %ld\n", SM_COLUMNS_D(exptotal->yexp), cols - 1);
5858
exit(-1);
5959
}
6060

61-
// Note that the first row of the result matrix is t
62-
for (long i = 1; i < rows; ++i)
61+
for (long i = 0; i < rows; ++i)
6362
{
6463
for (long j = 0; j < cols; ++j)
6564
{
66-
const sunrealtype diff = SM_ELEMENT_D(result, i, j) - SM_ELEMENT_D(exptotal->yexp, j, i - 1);
65+
const sunrealtype diff = SM_ELEMENT_D(result, i, j) - SM_ELEMENT_D(exptotal->yexp, j, i);
6766
J += diff * diff;
6867
}
6968
}

modules/cuqdyn-c/src/ode_solver.c

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ TransposedStates solve_ode(N_Vector parameters, N_Vector initial_values, sunreal
4444
sunrealtype t;
4545

4646
N_Vector yout = New_Serial(NV_LENGTH_S(initial_values));
47-
int result_rows = cuqdyn_conf->ode_expr.y_count + 1; // We add the time col
47+
int result_rows = cuqdyn_conf->ode_expr.y_count;
4848
TransposedStates result = NewDenseMatrix(result_rows, NV_LENGTH_S(times));
4949

5050
for (int i = 0; i < NV_LENGTH_S(times); ++i)
@@ -53,8 +53,7 @@ TransposedStates solve_ode(N_Vector parameters, N_Vector initial_values, sunreal
5353

5454
if (actual_time == t0)
5555
{
56-
SM_COLUMN_D(result, i)[0] = t0;
57-
memcpy(&SM_COLUMN_D(result, i)[1], NV_DATA_S(initial_values), NV_LENGTH_S(initial_values) * sizeof(sunrealtype));
56+
memcpy(SM_COLUMN_D(result, i), NV_DATA_S(initial_values), NV_LENGTH_S(initial_values) * sizeof(sunrealtype));
5857
continue;
5958
}
6059

@@ -65,8 +64,7 @@ TransposedStates solve_ode(N_Vector parameters, N_Vector initial_values, sunreal
6564
return NULL;
6665
}
6766

68-
SM_COLUMN_D(result, i)[0] = t;
69-
memcpy(&SM_COLUMN_D(result, i)[1], NV_DATA_S(yout), NV_LENGTH_S(yout) * sizeof(sunrealtype));
67+
memcpy(SM_COLUMN_D(result, i), NV_DATA_S(yout), NV_LENGTH_S(yout) * sizeof(sunrealtype));
7068
}
7169

7270
N_VDestroy(yout);

modules/cuqdyn-c/src/states_transformer.c

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,12 @@ ObservablesTransposedStates transform_states(TransposedStates transposed_states)
2323
const int rows = SM_ROWS_D(transposed_states);
2424
const int cols = SM_COLUMNS_D(transposed_states);
2525

26-
ObservablesTransposedStates transposed_result = NewDenseMatrix(conf->states_transformer.count + 1, cols);
26+
ObservablesTransposedStates transposed_result = NewDenseMatrix(conf->states_transformer.count, cols);
2727

2828
for (int i = 0; i < cols; ++i)
2929
{
30-
// Copy the time point to the first element of the transformed state
31-
SM_COLUMN_D(transposed_result, i)[0] = SM_COLUMN_D(transposed_states, i)[0];
32-
33-
// The first element is the time point
34-
sunrealtype *input = &SM_COLUMN_D(transposed_states, i)[1];
35-
sunrealtype *dest = &SM_COLUMN_D(transposed_result, i)[1];
30+
sunrealtype *input = SM_COLUMN_D(transposed_states, i);
31+
sunrealtype *dest = SM_COLUMN_D(transposed_result, i);
3632

3733
eval_states_transformer_expr(input, dest, get_cuqdyn_context());
3834
}

tests/test_ode_solver.c

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,12 @@ void test_lotka_volterra()
6262
const int cols = SM_COLUMNS_D(result);
6363

6464
assert(cols == 9);
65-
assert(rows == 3);
65+
assert(rows == 2);
6666

67-
assert(fabs(SM_ELEMENT_D(result, 0, 0) - 1.0) < 0.0001);
68-
assert(fabs(SM_ELEMENT_D(result, 1, 0) - 15.10) < 0.01);
69-
assert(fabs(SM_ELEMENT_D(result, 2, 0) - 3.883) < 0.001);
70-
assert(fabs(SM_ELEMENT_D(result, 1, 6) - 53.79) < 0.01);
71-
assert(fabs(SM_ELEMENT_D(result, 2, 6) - 5.456) < 0.001);
67+
assert(fabs(SM_ELEMENT_D(result, 0, 0) - 15.10) < 0.01);
68+
assert(fabs(SM_ELEMENT_D(result, 1, 0) - 3.883) < 0.001);
69+
assert(fabs(SM_ELEMENT_D(result, 0, 6) - 53.79) < 0.01);
70+
assert(fabs(SM_ELEMENT_D(result, 1, 6) - 5.456) < 0.001);
7271

7372
destroy_cuqdyn_context(context);
7473
SUNMatDestroy(result);
@@ -114,13 +113,12 @@ void test_alpha_pienene()
114113
const int cols = SM_COLUMNS_D(result);
115114

116115
assert(cols == 9);
117-
assert(rows == 6);
116+
assert(rows == 5);
118117

119-
assert(fabs(SM_ELEMENT_D(result, 0, 0) - 0.0) < 0.0001);
120-
assert(fabs(SM_ELEMENT_D(result, 1, 0) - 100) < 0.01);
121-
assert(fabs(SM_ELEMENT_D(result, 3, 0) - 0) < 0.001);
122-
assert(fabs(SM_ELEMENT_D(result, 1, 6) - 2.510e+01) < 2);
123-
assert(fabs(SM_ELEMENT_D(result, 2, 6) - 4.814e+01) < 2);
118+
assert(fabs(SM_ELEMENT_D(result, 0, 0) - 100) < 0.01);
119+
assert(fabs(SM_ELEMENT_D(result, 2, 0) - 0) < 0.001);
120+
assert(fabs(SM_ELEMENT_D(result, 0, 6) - 2.510e+01) < 2);
121+
assert(fabs(SM_ELEMENT_D(result, 1, 6) - 4.814e+01) < 2);
124122

125123
destroy_cuqdyn_context(context);
126124
SUNMatDestroy(result);
@@ -161,12 +159,11 @@ void test_logistic_model()
161159
const int cols = SM_COLUMNS_D(result);
162160

163161
assert(cols == 11);
164-
assert(rows == 2);
162+
assert(rows == 1);
165163

166-
assert(fabs(SM_ELEMENT_D(result, 0, 0) - 0.0) < 0.0001);
167-
assert(fabs(SM_ELEMENT_D(result, 1, 0) - 10) < 0.01);
168-
assert(fabs(SM_ELEMENT_D(result, 1, 5) - 9.428e+01) < 0.01);
169-
assert(fabs(SM_ELEMENT_D(result, 1, 6) - 9.782e+01) < 0.01);
164+
assert(fabs(SM_ELEMENT_D(result, 0, 0) - 10) < 0.01);
165+
assert(fabs(SM_ELEMENT_D(result, 0, 5) - 9.428e+01) < 0.01);
166+
assert(fabs(SM_ELEMENT_D(result, 0, 6) - 9.782e+01) < 0.01);
170167

171168
destroy_cuqdyn_context(context);
172169
SUNMatDestroy(result);

tests/test_transform_states.c

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,33 +12,39 @@ int main()
1212
CuqdynConf *conf = get_cuqdyn_conf(context);
1313

1414
sunrealtype states_values[] = {
15-
0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
16-
1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2
15+
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
16+
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2
1717
};
1818
sunrealtype expected_transformation_values[] = {
19-
0, 1, 2, 1, 3, 1, 1,
20-
1, 2, 4, 2, 6, 2, 2
19+
1, 2, 1, 3, 1, 1,
20+
2, 4, 2, 6, 2, 2
2121
};
2222

2323
assert(6 == conf->states_transformer.count);
2424

25-
TransposedStates states = NewDenseMatrix(16, 2);
25+
TransposedStates states = NewDenseMatrix(15, 2);
2626

27-
for (int i = 0; i < 2; ++i) {
28-
for (int j = 0; j < 16; ++j) {
27+
for (int i = 0; i < 2; ++i)
28+
{
29+
for (int j = 0; j < 15; ++j)
30+
{
2931
SM_ELEMENT_D(states, j, i) = states_values[16 * i + j];
3032
}
3133
}
3234

3335
ObservablesTransposedStates transformation = transform_states(states);
3436

35-
assert(SM_ROWS_D(transformation) == 7);
36-
assert(SM_COLUMNS_D(transformation) == 2);
37-
38-
for (int i = 0; i < 2; ++i) {
39-
for (int j = 0; j < 7; ++j) {
40-
41-
assert(expected_transformation_values[7 * i + j] == SM_ELEMENT_D(transformation, j, i));
37+
int rows = SM_ROWS_D(transformation);
38+
int columns = SM_COLUMNS_D(transformation);
39+
40+
assert(rows == 6);
41+
assert(columns == 2);
42+
43+
for (int i = 0; i < columns; ++i)
44+
{
45+
for (int j = 0; j < rows; ++j)
46+
{
47+
assert(expected_transformation_values[rows * i + j] == SM_ELEMENT_D(transformation, j, i));
4248
}
4349
}
4450

0 commit comments

Comments
 (0)