22#include <method_module/structure_paralleltestbed.h>
33#include <nvector/nvector_serial.h>
44#include <stdio.h>
5+ #include <stdlib.h>
56#include <string.h>
67#include <sunmatrix/sunmatrix_dense.h>
78
89#include "config.h"
910#include "cuqdyn.h"
1011#include "ode_solver.h"
1112#include "states_transformer.h"
13+ #include "sundials/sundials_types.h"
1214
1315extern void eval_f_exprs (sunrealtype t , sunrealtype * y , sunrealtype * ydot , sunrealtype * params , CuqDynContext context );
1416
@@ -62,7 +64,6 @@ void *obj_func(double *x, void *data)
6264
6365 const long rows = SM_ROWS_D (result );
6466 const long cols = SM_COLUMNS_D (result );
65- sunrealtype J = 0.0 ;
6667
6768 if (SM_ROWS_D (exptotal -> yexp ) != rows )
6869 {
@@ -78,19 +79,111 @@ void *obj_func(double *x, void *data)
7879 exit (-1 );
7980 }
8081
81- for (long i = 0 ; i < rows ; ++ i )
82+ double * J = malloc (sizeof (double ));
83+ * J = 0.0 ;
84+ double * g = malloc (sizeof (double ));
85+ * g = 0.0 ;
86+ double * R = malloc (sizeof (double ) * rows * cols );
87+
88+ for (long j = 0 ; j < cols ; ++ j )
8289 {
83- for (long j = 0 ; j < cols ; ++ j )
90+ for (long i = 0 ; i < rows ; ++ i )
8491 {
8592 const sunrealtype diff = SM_ELEMENT_D (result , i , j ) - SM_ELEMENT_D (exptotal -> yexp , i , j );
86- J += diff * diff ;
93+ R [j * rows + i ] = diff ;
94+ * J += diff * diff ;
8795 }
8896 }
8997
90- res -> value = J ;
98+ res -> size_j = 1 ;
99+ res -> J = J ;
100+ res -> value = * J ;
101+ res -> size_r = rows * cols ;
102+ res -> R = R ;
103+ res -> g = g ;
91104
92105 N_VDestroy (parameters );
93106 SUNMatDestroy (result );
94107
95108 return res ;
96109}
110+
111+ void * obj_func2 (double * x , void * data )
112+ {
113+ CuqdynConf * conf = get_cuqdyn_conf (get_cuqdyn_context ());
114+ experiment_total * exptotal = data ;
115+ output_function * res = calloc (1 , sizeof (output_function ));
116+
117+ N_Vector parameters = New_Serial (conf -> ode_expr .p_count );
118+ memcpy (NV_DATA_S (parameters ), x , NV_LENGTH_S (parameters ) * sizeof (sunrealtype ));
119+
120+ N_Vector texp = New_Serial (NV_LENGTH_S (exptotal -> texp ));
121+ memcpy (NV_DATA_S (texp ), NV_DATA_S (exptotal -> texp ), NV_LENGTH_S (exptotal -> texp ) * sizeof (sunrealtype ));
122+
123+ #ifdef DEBUG
124+ fprintf (stdout , "[DEBUG] [OBJ FUNC] " );
125+ fprintf (stdout , "Params: " );
126+ for (int i = 0 ; i < NV_LENGTH_S (parameters ); i ++ )
127+ {
128+ fprintf (stdout , "%f " , NV_Ith_S (parameters , i ));
129+ }
130+ fprintf (stdout , "\n" );
131+ #endif
132+
133+ const sunrealtype t0 = NV_Ith_S (texp , 0 );
134+
135+ TransposedStates ode_solution = solve_ode (parameters , exptotal -> initial_values , t0 , texp );
136+ ObservablesTransposedStates result = transform_states (ode_solution );
137+
138+ const long rows = SM_ROWS_D (result );
139+ const long cols = SM_COLUMNS_D (result );
140+
141+ if (SM_ROWS_D (exptotal -> yexp ) != rows )
142+ {
143+ fprintf (stderr , "ERROR: The yexp rows don't match the ode result rows: %ld vs %ld\n" , SM_ROWS_D (exptotal -> yexp ),
144+ rows );
145+ exit (-1 );
146+ }
147+
148+ if (SM_COLUMNS_D (exptotal -> yexp ) != cols )
149+ {
150+ fprintf (stderr , "ERROR: The yexp cols don't match the ode result cols: %ld vs %ld\n" ,
151+ SM_COLUMNS_D (exptotal -> yexp ), cols - 1 );
152+ exit (-1 );
153+ }
154+
155+ double * J = malloc (sizeof (double ));
156+ * J = 0.0 ;
157+ double * g = malloc (sizeof (double ));
158+ * g = 0.0 ;
159+ double * R = malloc (sizeof (double ) * rows * cols );
160+
161+ for (long j = 0 ; j < cols ; ++ j )
162+ {
163+ double col_mean = 0.0 ;
164+ for (long i = 0 ; i < rows ; ++ i )
165+ {
166+ col_mean += SM_ELEMENT_D (exptotal -> yexp , i , j );
167+ }
168+ col_mean /= rows ;
169+
170+ for (long i = 0 ; i < rows ; ++ i )
171+ {
172+ const sunrealtype diff = SM_ELEMENT_D (result , i , j ) - SM_ELEMENT_D (exptotal -> yexp , i , j );
173+ R [j * rows + i ] = diff / col_mean ;
174+ * J += diff * diff ;
175+ }
176+ }
177+
178+ res -> size_j = 1 ;
179+ res -> J = J ;
180+ res -> value = * J * 100 ;
181+ res -> size_r = rows * cols ;
182+ res -> R = R ;
183+ res -> g = g ;
184+
185+ N_VDestroy (parameters );
186+ SUNMatDestroy (result );
187+
188+ return res ;
189+ }
0 commit comments