@@ -20,10 +20,10 @@ namespace {
2020 * Common helpers *
2121 *********************/
2222
23- using xmatrix = decltype (NodeBipartiteObs::column_features );
23+ using xmatrix = decltype (NodeBipartiteObs::variable_features );
2424using value_type = xmatrix::value_type;
2525
26- using ColumnFeatures = NodeBipartiteObs::ColumnFeatures ;
26+ using VariableFeatures = NodeBipartiteObs::VariableFeatures ;
2727using RowFeatures = NodeBipartiteObs::RowFeatures;
2828
2929value_type constexpr cste = 5 .;
@@ -34,9 +34,9 @@ SCIP_Real obj_l2_norm(SCIP* const scip) noexcept {
3434 return norm > 0 ? norm : 1 .;
3535}
3636
37- /* *****************************************
38- * Column features extraction functions *
39- ******************************************/
37+ /* ******************************************
38+ * Variable features extraction functions *
39+ ******************************************* /
4040
4141std::optional<SCIP_Real> upper_bound (SCIP* const scip, SCIP_COL* const col) noexcept {
4242 auto const ub_val = SCIPcolGetUb (col);
@@ -85,11 +85,11 @@ std::optional<SCIP_Real> avg_sol(SCIP* const scip, SCIP_VAR* const var) noexcept
8585 return {};
8686}
8787
88- std::optional<SCIP_Real> feas_frac (SCIP* const scip, SCIP_VAR* const var, SCIP_COL* const col ) noexcept {
88+ std::optional<SCIP_Real> feas_frac (SCIP* const scip, SCIP_VAR* const var) noexcept {
8989 if (SCIPvarGetType (var) == SCIP_VARTYPE_CONTINUOUS) {
9090 return {};
9191 }
92- return SCIPfeasFrac (scip, SCIPcolGetPrimsol (col ));
92+ return SCIPfeasFrac (scip, SCIPvarGetLPSol (var ));
9393}
9494
9595/* * Convert an enum to its underlying index. */
@@ -98,89 +98,89 @@ template <typename E> constexpr auto idx(E e) {
9898}
9999
100100template <typename Features>
101- void set_static_features_for_col (Features&& out, SCIP_VAR* const var, SCIP_COL* const col , value_type obj_norm) {
102- out[idx (ColumnFeatures ::objective)] = SCIPcolGetObj (col ) / obj_norm;
101+ void set_static_features_for_var (Features&& out, SCIP_VAR* const var, value_type obj_norm) {
102+ out[idx (VariableFeatures ::objective)] = SCIPvarGetObj (var ) / obj_norm;
103103 // On-hot enconding of varaible type
104- out[idx (ColumnFeatures ::is_type_binary)] = 0 .;
105- out[idx (ColumnFeatures ::is_type_integer)] = 0 .;
106- out[idx (ColumnFeatures ::is_type_implicit_integer)] = 0 .;
107- out[idx (ColumnFeatures ::is_type_continuous)] = 0 .;
104+ out[idx (VariableFeatures ::is_type_binary)] = 0 .;
105+ out[idx (VariableFeatures ::is_type_integer)] = 0 .;
106+ out[idx (VariableFeatures ::is_type_implicit_integer)] = 0 .;
107+ out[idx (VariableFeatures ::is_type_continuous)] = 0 .;
108108 switch (SCIPvarGetType (var)) {
109109 case SCIP_VARTYPE_BINARY:
110- out[idx (ColumnFeatures ::is_type_binary)] = 1 .;
110+ out[idx (VariableFeatures ::is_type_binary)] = 1 .;
111111 break ;
112112 case SCIP_VARTYPE_INTEGER:
113- out[idx (ColumnFeatures ::is_type_integer)] = 1 .;
113+ out[idx (VariableFeatures ::is_type_integer)] = 1 .;
114114 break ;
115115 case SCIP_VARTYPE_IMPLINT:
116- out[idx (ColumnFeatures ::is_type_implicit_integer)] = 1 .;
116+ out[idx (VariableFeatures ::is_type_implicit_integer)] = 1 .;
117117 break ;
118118 case SCIP_VARTYPE_CONTINUOUS:
119- out[idx (ColumnFeatures ::is_type_continuous)] = 1 .;
119+ out[idx (VariableFeatures ::is_type_continuous)] = 1 .;
120120 break ;
121121 default :
122122 assert (false ); // All enum cases must be handled
123123 }
124124}
125125
126126template <typename Features>
127- void set_dynamic_features_for_col (
127+ void set_dynamic_features_for_var (
128128 Features&& out,
129129 SCIP* const scip,
130130 SCIP_VAR* const var,
131131 SCIP_COL* const col,
132132 value_type obj_norm,
133133 value_type n_lps) {
134- out[idx (ColumnFeatures ::has_lower_bound)] = static_cast <value_type>(lower_bound (scip, col).has_value ());
135- out[idx (ColumnFeatures ::has_upper_bound)] = static_cast <value_type>(upper_bound (scip, col).has_value ());
136- out[idx (ColumnFeatures ::normed_reduced_cost)] = SCIPgetColRedcost (scip, col ) / obj_norm;
137- out[idx (ColumnFeatures ::solution_value)] = SCIPcolGetPrimsol (col );
138- out[idx (ColumnFeatures ::solution_frac)] = feas_frac (scip, var, col ).value_or (0 .);
139- out[idx (ColumnFeatures ::is_solution_at_lower_bound)] = static_cast <value_type>(is_prim_sol_at_lb (scip, col));
140- out[idx (ColumnFeatures ::is_solution_at_upper_bound)] = static_cast <value_type>(is_prim_sol_at_ub (scip, col));
141- out[idx (ColumnFeatures ::scaled_age)] = static_cast <value_type>(SCIPcolGetAge (col)) / (n_lps + cste);
142- out[idx (ColumnFeatures ::incumbent_value)] = best_sol_val (scip, var).value_or (nan);
143- out[idx (ColumnFeatures ::average_incumbent_value)] = avg_sol (scip, var).value_or (nan);
134+ out[idx (VariableFeatures ::has_lower_bound)] = static_cast <value_type>(lower_bound (scip, col).has_value ());
135+ out[idx (VariableFeatures ::has_upper_bound)] = static_cast <value_type>(upper_bound (scip, col).has_value ());
136+ out[idx (VariableFeatures ::normed_reduced_cost)] = SCIPgetVarRedcost (scip, var ) / obj_norm;
137+ out[idx (VariableFeatures ::solution_value)] = SCIPvarGetLPSol (var );
138+ out[idx (VariableFeatures ::solution_frac)] = feas_frac (scip, var).value_or (0 .);
139+ out[idx (VariableFeatures ::is_solution_at_lower_bound)] = static_cast <value_type>(is_prim_sol_at_lb (scip, col));
140+ out[idx (VariableFeatures ::is_solution_at_upper_bound)] = static_cast <value_type>(is_prim_sol_at_ub (scip, col));
141+ out[idx (VariableFeatures ::scaled_age)] = static_cast <value_type>(SCIPcolGetAge (col)) / (n_lps + cste);
142+ out[idx (VariableFeatures ::incumbent_value)] = best_sol_val (scip, var).value_or (nan);
143+ out[idx (VariableFeatures ::average_incumbent_value)] = avg_sol (scip, var).value_or (nan);
144144 // On-hot encoding
145- out[idx (ColumnFeatures ::is_basis_lower)] = 0 .;
146- out[idx (ColumnFeatures ::is_basis_basic)] = 0 .;
147- out[idx (ColumnFeatures ::is_basis_upper)] = 0 .;
148- out[idx (ColumnFeatures ::is_basis_zero)] = 0 .;
145+ out[idx (VariableFeatures ::is_basis_lower)] = 0 .;
146+ out[idx (VariableFeatures ::is_basis_basic)] = 0 .;
147+ out[idx (VariableFeatures ::is_basis_upper)] = 0 .;
148+ out[idx (VariableFeatures ::is_basis_zero)] = 0 .;
149149 switch (SCIPcolGetBasisStatus (col)) {
150150 case SCIP_BASESTAT_LOWER:
151- out[idx (ColumnFeatures ::is_basis_lower)] = 1 .;
151+ out[idx (VariableFeatures ::is_basis_lower)] = 1 .;
152152 break ;
153153 case SCIP_BASESTAT_BASIC:
154- out[idx (ColumnFeatures ::is_basis_basic)] = 1 .;
154+ out[idx (VariableFeatures ::is_basis_basic)] = 1 .;
155155 break ;
156156 case SCIP_BASESTAT_UPPER:
157- out[idx (ColumnFeatures ::is_basis_upper)] = 1 .;
157+ out[idx (VariableFeatures ::is_basis_upper)] = 1 .;
158158 break ;
159159 case SCIP_BASESTAT_ZERO:
160- out[idx (ColumnFeatures ::is_basis_zero)] = 1 .;
160+ out[idx (VariableFeatures ::is_basis_zero)] = 1 .;
161161 break ;
162162 default :
163163 assert (false ); // All enum cases must be handled
164164 }
165165}
166166
167- void set_features_for_all_cols (xmatrix& out, scip::Model& model, bool const update_static) {
167+ void set_features_for_all_vars (xmatrix& out, scip::Model& model, bool const update_static) {
168168 auto * const scip = model.get_scip_ptr ();
169169
170170 // Contant reused in every iterations
171171 auto const n_lps = static_cast <value_type>(SCIPgetNLPs (scip));
172172 auto const obj_norm = obj_l2_norm (scip);
173173
174- auto const columns = model.lp_columns ();
175- auto const n_columns = columns .size ();
176- for (std::size_t col_idx = 0 ; col_idx < n_columns ; ++col_idx ) {
177- auto * const col = columns[col_idx ];
178- auto * const var = SCIPcolGetVar (col );
179- auto features = xt::row (out, static_cast <std::ptrdiff_t >(col_idx ));
174+ auto const variables = model.variables ();
175+ auto const n_vars = variables .size ();
176+ for (std::size_t var_idx = 0 ; var_idx < n_vars ; ++var_idx ) {
177+ auto * const var = variables[var_idx ];
178+ auto * const col = SCIPvarGetCol (var );
179+ auto features = xt::row (out, static_cast <std::ptrdiff_t >(var_idx ));
180180 if (update_static) {
181- set_static_features_for_col (features, var, col , obj_norm);
181+ set_static_features_for_var (features, var, obj_norm);
182182 }
183- set_dynamic_features_for_col (features, scip, var, col, obj_norm, n_lps);
183+ set_dynamic_features_for_var (features, scip, var, col, obj_norm, n_lps);
184184 }
185185}
186186
@@ -326,7 +326,7 @@ utility::coo_matrix<value_type> extract_edge_features(scip::Model& model) {
326326 if (scip::get_unshifted_lhs (scip, row).has_value ()) {
327327 for (std::size_t k = 0 ; k < row_nnz; ++k) {
328328 indices (0 , j + k) = i;
329- indices (1 , j + k) = static_cast <std::size_t >(SCIPcolGetLPPos (row_cols[k]));
329+ indices (1 , j + k) = static_cast <std::size_t >(SCIPcolGetVarProbindex (row_cols[k]));
330330 values[j + k] = -row_vals[k];
331331 }
332332 j += row_nnz;
@@ -335,7 +335,7 @@ utility::coo_matrix<value_type> extract_edge_features(scip::Model& model) {
335335 if (scip::get_unshifted_rhs (scip, row).has_value ()) {
336336 for (std::size_t k = 0 ; k < row_nnz; ++k) {
337337 indices (0 , j + k) = i;
338- indices (1 , j + k) = static_cast <std::size_t >(SCIPcolGetLPPos (row_cols[k]));
338+ indices (1 , j + k) = static_cast <std::size_t >(SCIPcolGetVarProbindex (row_cols[k]));
339339 values[j + k] = row_vals[k];
340340 }
341341 j += row_nnz;
@@ -344,8 +344,9 @@ utility::coo_matrix<value_type> extract_edge_features(scip::Model& model) {
344344 }
345345
346346 auto const n_rows = n_ineq_rows (model);
347- auto const n_cols = static_cast <std::size_t >(SCIPgetNLPCols (scip));
348- return {values, indices, {n_rows, n_cols}};
347+ // Change this here for variables
348+ auto const n_vars = static_cast <std::size_t >(SCIPgetNVars (scip));
349+ return {values, indices, {n_rows, n_vars}};
349350}
350351
351352auto is_on_root_node (scip::Model& model) -> bool {
@@ -355,17 +356,18 @@ auto is_on_root_node(scip::Model& model) -> bool {
355356
356357auto extract_observation_fully (scip::Model& model) -> NodeBipartiteObs {
357358 auto obs = NodeBipartiteObs{
358- xmatrix::from_shape ({model.lp_columns ().size (), NodeBipartiteObs::n_column_features}),
359+ // Change this here for variables
360+ xmatrix::from_shape ({model.variables ().size (), NodeBipartiteObs::n_variable_features}),
359361 xmatrix::from_shape ({n_ineq_rows (model), NodeBipartiteObs::n_row_features}),
360362 extract_edge_features (model),
361363 };
362- set_features_for_all_cols (obs.column_features , model, true );
364+ set_features_for_all_vars (obs.variable_features , model, true );
363365 set_features_for_all_rows (obs.row_features , model, true );
364366 return obs;
365367}
366368
367369auto extract_observation_from_cache (scip::Model& model, NodeBipartiteObs obs) -> NodeBipartiteObs {
368- set_features_for_all_cols (obs.column_features , model, false );
370+ set_features_for_all_vars (obs.variable_features , model, false );
369371 set_features_for_all_rows (obs.row_features , model, false );
370372 return obs;
371373}
0 commit comments