Skip to content

Commit 5e7eef7

Browse files
committed
assigning J R and g
1 parent d33f9c5 commit 5e7eef7

File tree

3 files changed

+101
-7
lines changed

3 files changed

+101
-7
lines changed

modules/cuqdyn-c/include/functions.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
/// Function used to solve the ODE using cvodes
77
int ode_model_fun(sunrealtype t, N_Vector y, N_Vector ydot, void *data);
8-
/// Objetive function for the lotka_volterra problem used by the sacess library
8+
/// Objetive functions used by the sacess library
99
void *obj_func(double *x, void *data);
10+
void *obj_func2(double *x, void *data);
1011

1112
#endif // LOTKA_VOLTERRA_H

modules/cuqdyn-c/src/ess_solver.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ N_Vector execute_ess_solver(const char *file, const char *path, N_Vector texp, S
105105
}
106106
}
107107

108-
execute_Solver(exptotal, &result, obj_func);
108+
execute_Solver(exptotal, &result, obj_func2);
109109

110110
N_Vector predicted_params = New_Serial(exptotal[0].test.bench.dim);
111111
destroyexp(exptotal);

modules/cuqdyn-c/src/functions.c

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
#include <method_module/structure_paralleltestbed.h>
33
#include <nvector/nvector_serial.h>
44
#include <stdio.h>
5+
#include <stdlib.h>
56
#include <string.h>
67
#include <sunmatrix/sunmatrix_dense.h>
78

89
#include "config.h"
910
#include "cuqdyn.h"
1011
#include "ode_solver.h"
1112
#include "states_transformer.h"
13+
#include "sundials/sundials_types.h"
1214

1315
extern void eval_f_exprs(sunrealtype t, sunrealtype *y, sunrealtype *ydot, sunrealtype *params, CuqDynContext context);
1416

@@ -62,7 +64,6 @@ void *obj_func(double *x, void *data)
6264

6365
const long rows = SM_ROWS_D(result);
6466
const long cols = SM_COLUMNS_D(result);
65-
sunrealtype J = 0.0;
6667

6768
if (SM_ROWS_D(exptotal->yexp) != rows)
6869
{
@@ -78,19 +79,111 @@ void *obj_func(double *x, void *data)
7879
exit(-1);
7980
}
8081

81-
for (long i = 0; i < rows; ++i)
82+
double *J = malloc(sizeof(double));
83+
*J = 0.0;
84+
double *g = malloc(sizeof(double));
85+
*g = 0.0;
86+
double *R = malloc(sizeof(double) * rows * cols);
87+
88+
for (long j = 0; j < cols; ++j)
8289
{
83-
for (long j = 0; j < cols; ++j)
90+
for (long i = 0; i < rows; ++i)
8491
{
8592
const sunrealtype diff = SM_ELEMENT_D(result, i, j) - SM_ELEMENT_D(exptotal->yexp, i, j);
86-
J += diff * diff;
93+
R[j * rows + i] = diff;
94+
*J += diff * diff;
8795
}
8896
}
8997

90-
res->value = J;
98+
res->size_j = 1;
99+
res->J = J;
100+
res->value = *J;
101+
res->size_r = rows * cols;
102+
res->R = R;
103+
res->g = g;
91104

92105
N_VDestroy(parameters);
93106
SUNMatDestroy(result);
94107

95108
return res;
96109
}
110+
111+
void *obj_func2(double *x, void *data)
112+
{
113+
CuqdynConf *conf = get_cuqdyn_conf(get_cuqdyn_context());
114+
experiment_total *exptotal = data;
115+
output_function *res = calloc(1, sizeof(output_function));
116+
117+
N_Vector parameters = New_Serial(conf->ode_expr.p_count);
118+
memcpy(NV_DATA_S(parameters), x, NV_LENGTH_S(parameters) * sizeof(sunrealtype));
119+
120+
N_Vector texp = New_Serial(NV_LENGTH_S(exptotal->texp));
121+
memcpy(NV_DATA_S(texp), NV_DATA_S(exptotal->texp), NV_LENGTH_S(exptotal->texp) * sizeof(sunrealtype));
122+
123+
#ifdef DEBUG
124+
fprintf(stdout, "[DEBUG] [OBJ FUNC] ");
125+
fprintf(stdout, "Params: ");
126+
for (int i = 0; i < NV_LENGTH_S(parameters); i++)
127+
{
128+
fprintf(stdout, "%f ", NV_Ith_S(parameters, i));
129+
}
130+
fprintf(stdout, "\n");
131+
#endif
132+
133+
const sunrealtype t0 = NV_Ith_S(texp, 0);
134+
135+
TransposedStates ode_solution = solve_ode(parameters, exptotal->initial_values, t0, texp);
136+
ObservablesTransposedStates result = transform_states(ode_solution);
137+
138+
const long rows = SM_ROWS_D(result);
139+
const long cols = SM_COLUMNS_D(result);
140+
141+
if (SM_ROWS_D(exptotal->yexp) != rows)
142+
{
143+
fprintf(stderr, "ERROR: The yexp rows don't match the ode result rows: %ld vs %ld\n", SM_ROWS_D(exptotal->yexp),
144+
rows);
145+
exit(-1);
146+
}
147+
148+
if (SM_COLUMNS_D(exptotal->yexp) != cols)
149+
{
150+
fprintf(stderr, "ERROR: The yexp cols don't match the ode result cols: %ld vs %ld\n",
151+
SM_COLUMNS_D(exptotal->yexp), cols - 1);
152+
exit(-1);
153+
}
154+
155+
double *J = malloc(sizeof(double));
156+
*J = 0.0;
157+
double *g = malloc(sizeof(double));
158+
*g = 0.0;
159+
double *R = malloc(sizeof(double) * rows * cols);
160+
161+
for (long j = 0; j < cols; ++j)
162+
{
163+
double col_mean = 0.0;
164+
for (long i = 0; i < rows; ++i)
165+
{
166+
col_mean += SM_ELEMENT_D(exptotal->yexp, i, j);
167+
}
168+
col_mean /= rows;
169+
170+
for (long i = 0; i < rows; ++i)
171+
{
172+
const sunrealtype diff = SM_ELEMENT_D(result, i, j) - SM_ELEMENT_D(exptotal->yexp, i, j);
173+
R[j * rows + i] = diff / col_mean;
174+
*J += diff * diff;
175+
}
176+
}
177+
178+
res->size_j = 1;
179+
res->J = J;
180+
res->value = *J*100;
181+
res->size_r = rows * cols;
182+
res->R = R;
183+
res->g = g;
184+
185+
N_VDestroy(parameters);
186+
SUNMatDestroy(result);
187+
188+
return res;
189+
}

0 commit comments

Comments
 (0)