66
77#include "config.h"
88#include "cuqdyn.h"
9+ #include "matlab.h"
910
1011extern void mexpreval_init (CuqdynConf cuqdyn_conf );
1112
@@ -28,6 +29,8 @@ int ode_model_fun(sunrealtype t, N_Vector y, N_Vector ydot, void *user_data)
2829 return 0 ;
2930}
3031
32+ extern void eval_states_transformer_expr (sunrealtype * input , sunrealtype * output );
33+
3134/*
3235 * function [J,g,R]=prob_mod_lv(x,texp,yexp)
3336 * [tout,yout] =
@@ -54,10 +57,35 @@ void *ode_model_obj_func(double *x, void *data)
5457
5558 SUNMatrix result = solve_ode (parameters , exptotal -> initial_values , t0 , texp );
5659
57- // Objective function code:
5860 const int rows = SM_ROWS_D (result );
5961 const int cols = SM_COLUMNS_D (result );
6062
63+ // TODO: Optimize using transposed matrix
64+ if (conf -> states_transformer .count > 0 )
65+ {
66+ SUNMatrix transformed_result = NewDenseMatrix (rows , conf -> states_transformer .count + 1 ); // + 1 adds the time column
67+
68+ for (int i = 0 ; i < rows ; ++ i )
69+ {
70+ N_Vector input = copy_matrix_row (result , i , 1 , cols );
71+ sunrealtype * output = malloc (conf -> states_transformer .count * sizeof (sunrealtype ));
72+
73+ eval_states_transformer_expr (NV_DATA_S (input ), output );
74+
75+ SM_ELEMENT_D (transformed_result , i , 0 ) = SM_ELEMENT_D (result , i , 0 );
76+ for (int j = 0 ; j < conf -> states_transformer .count ; ++ j )
77+ {
78+ SM_ELEMENT_D (transformed_result , i , j + 1 ) = NV_Ith_S (input , j );
79+ }
80+
81+ N_VDestroy (input );
82+ free (output );
83+ }
84+
85+ SUNMatDestroy (result );
86+ result = transformed_result ;
87+ }
88+
6189 sunrealtype J = 0.0 ;
6290
6391 if (SM_ROWS_D (exptotal -> yexp ) != rows )
0 commit comments