Skip to content

Commit c20f3c1

Browse files
committed
making the transformation in the obj function
1 parent 2ce1875 commit c20f3c1

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

modules/cuqdyn-c/src/functions.c

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "config.h"
88
#include "cuqdyn.h"
9+
#include "matlab.h"
910

1011
extern void mexpreval_init(CuqdynConf cuqdyn_conf);
1112

@@ -28,6 +29,8 @@ int ode_model_fun(sunrealtype t, N_Vector y, N_Vector ydot, void *user_data)
2829
return 0;
2930
}
3031

32+
extern void eval_states_transformer_expr(sunrealtype *input, sunrealtype *output);
33+
3134
/*
3235
* function [J,g,R]=prob_mod_lv(x,texp,yexp)
3336
* [tout,yout] =
@@ -54,10 +57,35 @@ void *ode_model_obj_func(double *x, void *data)
5457

5558
SUNMatrix result = solve_ode(parameters, exptotal->initial_values, t0, texp);
5659

57-
// Objective function code:
5860
const int rows = SM_ROWS_D(result);
5961
const int cols = SM_COLUMNS_D(result);
6062

63+
// TODO: Optimize using transposed matrix
64+
if (conf->states_transformer.count > 0)
65+
{
66+
SUNMatrix transformed_result = NewDenseMatrix(rows, conf->states_transformer.count + 1); // + 1 adds the time column
67+
68+
for (int i = 0; i < rows; ++i)
69+
{
70+
N_Vector input = copy_matrix_row(result, i, 1, cols);
71+
sunrealtype *output = malloc(conf->states_transformer.count * sizeof(sunrealtype));
72+
73+
eval_states_transformer_expr(NV_DATA_S(input), output);
74+
75+
SM_ELEMENT_D(transformed_result, i, 0) = SM_ELEMENT_D(result, i, 0);
76+
for (int j = 0; j < conf->states_transformer.count; ++j)
77+
{
78+
SM_ELEMENT_D(transformed_result, i, j + 1) = NV_Ith_S(input, j);
79+
}
80+
81+
N_VDestroy(input);
82+
free(output);
83+
}
84+
85+
SUNMatDestroy(result);
86+
result = transformed_result;
87+
}
88+
6189
sunrealtype J = 0.0;
6290

6391
if (SM_ROWS_D(exptotal->yexp) != rows)

modules/mexpreval/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub unsafe extern "C" fn eval_f_exprs(t: f64, y: *mut f64, ydot: *mut f64, param
4040

4141
#[allow(clippy::missing_safety_doc)]
4242
#[no_mangle]
43-
pub unsafe extern "C" fn eval_obs_expr(input_state_vec: *mut f64, output_state_vec: *mut f64) {
43+
pub unsafe extern "C" fn eval_states_transformer_expr(input_state_vec: *mut f64, output_state_vec: *mut f64) {
4444
let cuqdyn_conf = CUQDYN_CONF.as_ref().unwrap();
4545

4646
let input_state_slice = slice::from_raw_parts(input_state_vec, cuqdyn_conf.ode_expr().len());

0 commit comments

Comments
 (0)