22#include < cstddef>
33#include < limits>
44
5+ #include < scip/scip.h>
6+ #include < scip/struct_lp.h>
57#include < xtensor/xview.hpp>
68
79#include " ecole/observation/nodebipartite.hpp"
1113namespace ecole {
1214namespace observation {
1315
16+ namespace {
17+
18+ /* ********************
19+ * Common helpers *
20+ *********************/
21+
1422using tensor = decltype (NodeBipartiteObs::column_features);
1523using value_type = tensor::value_type;
1624
17- static value_type constexpr cste = 5 .;
18- static value_type constexpr nan = std::numeric_limits<value_type>::quiet_NaN();
19- static auto constexpr n_row_feat = 5 ;
20- static auto constexpr n_col_feat =
21- 11 + scip::enum_size<scip::var_type>::value + scip::enum_size<scip::base_stat>::value;
25+ value_type constexpr cste = 5 .;
26+ value_type constexpr nan = std::numeric_limits<value_type>::quiet_NaN();
2227
23- static value_type get_obj_norm ( scip::Model const & model) {
24- auto norm = SCIPgetObjNorm (model. get_scip_ptr () );
28+ scip::real obj_l2_norm (Scip* const scip) noexcept {
29+ auto const norm = SCIPgetObjNorm (scip );
2530 return norm > 0 ? norm : 1 .;
2631}
2732
28- static auto extract_col_feat (scip::Model const & model) {
29- tensor col_feat{{model.lp_columns ().size , n_col_feat}, 0 .};
33+ /* *****************************************
34+ * Column features extraction functions *
35+ ******************************************/
36+
37+ nonstd::optional<scip::real> upper_bound (Scip* const scip, scip::Col* const col) noexcept {
38+ auto const ub_val = SCIPcolGetUb (col);
39+ if (SCIPisInfinity (scip, REALABS (ub_val))) {
40+ return {};
41+ }
42+ return ub_val;
43+ }
44+
45+ nonstd::optional<scip::real> lower_bound (Scip* const scip, scip::Col* const col) noexcept {
46+ auto const lb_val = SCIPcolGetLb (col);
47+ if (SCIPisInfinity (scip, REALABS (lb_val))) {
48+ return {};
49+ }
50+ return lb_val;
51+ }
52+
53+ bool is_prim_sol_at_lb (Scip* const scip, scip::Col* const col) noexcept {
54+ auto const lb_val = lower_bound (scip, col);
55+ if (lb_val) {
56+ return SCIPisEQ (scip, SCIPcolGetPrimsol (col), lb_val.value ());
57+ }
58+ return false ;
59+ }
60+
61+ bool is_prim_sol_at_ub (Scip* const scip, scip::Col* const col) noexcept {
62+ auto const ub_val = upper_bound (scip, col);
63+ if (ub_val) {
64+ return SCIPisEQ (scip, SCIPcolGetPrimsol (col), ub_val.value ());
65+ }
66+ return false ;
67+ }
68+
69+ nonstd::optional<scip::real> best_sol_val (Scip* const scip, scip::Var* const var) noexcept {
70+ auto const sol = SCIPgetBestSol (scip);
71+ if (sol != nullptr ) {
72+ return SCIPgetSolVal (scip, sol, var);
73+ }
74+ return {};
75+ }
76+
77+ nonstd::optional<scip::real> avg_sol (Scip* const scip, scip::Var* const var) noexcept {
78+ if (SCIPgetBestSol (scip) != nullptr ) {
79+ return SCIPvarGetAvgSol (var);
80+ }
81+ return {};
82+ }
83+
84+ nonstd::optional<scip::real>
85+ feas_frac (Scip* const scip, scip::Var* const var, scip::Col* const col) noexcept {
86+ if (SCIPvarGetType (var) == SCIP_VARTYPE_CONTINUOUS) {
87+ return {};
88+ }
89+ return SCIPfeasFrac (scip, SCIPcolGetPrimsol (col));
90+ }
91+
92+ auto extract_col_feat (scip::Model const & model) {
93+ auto constexpr n_col_feat =
94+ 11 + scip::enum_size<scip::var_type>::value + scip::enum_size<scip::base_stat>::value;
95+ auto const scip = model.get_scip_ptr ();
96+ tensor col_feat{{model.lp_columns ().size (), n_col_feat}, 0 .};
3097
31- value_type const n_lps = static_cast <value_type>(SCIPgetNLPs (model. get_scip_ptr () ));
32- value_type const obj_l2_norm = get_obj_norm (model );
98+ value_type const n_lps = static_cast <value_type>(SCIPgetNLPs (scip ));
99+ value_type const obj_norm = obj_l2_norm (scip );
33100
34101 auto iter = col_feat.begin ();
35102 for (auto const col : model.lp_columns ()) {
36- auto const var = col.var ();
37- *(iter++) = static_cast <value_type>(col.lb ().has_value ());
38- *(iter++) = static_cast <value_type>(col.ub ().has_value ());
39- *(iter++) = col.reduced_cost () / obj_l2_norm;
40- *(iter++) = col.obj () / obj_l2_norm;
41- *(iter++) = col.prim_sol ();
42- if (var.type_ () == SCIP_VARTYPE_CONTINUOUS)
43- *(iter++) = 0 .;
44- else
45- *(iter++) = col.prim_sol_frac ();
46- *(iter++) = static_cast <value_type>(col.is_prim_sol_at_lb ());
47- *(iter++) = static_cast <value_type>(col.is_prim_sol_at_ub ());
48- *(iter++) = static_cast <value_type>(col.age ()) / (n_lps + cste);
49- iter[static_cast <std::size_t >(col.basis_status ())] = 1 .;
103+ auto const var = SCIPcolGetVar (col);
104+ *(iter++) = static_cast <value_type>(lower_bound (scip, col).has_value ());
105+ *(iter++) = static_cast <value_type>(upper_bound (scip, col).has_value ());
106+ *(iter++) = SCIPgetColRedcost (scip, col) / obj_norm;
107+ *(iter++) = SCIPcolGetObj (col) / obj_norm;
108+ *(iter++) = SCIPcolGetPrimsol (col);
109+ *(iter++) = feas_frac (scip, var, col).value_or (0 .);
110+ *(iter++) = static_cast <value_type>(is_prim_sol_at_lb (scip, col));
111+ *(iter++) = static_cast <value_type>(is_prim_sol_at_ub (scip, col));
112+ *(iter++) = static_cast <value_type>(col->age ) / (n_lps + cste);
113+ iter[static_cast <std::size_t >(SCIPcolGetBasisStatus (col))] = 1 .;
50114 iter += scip::enum_size<scip::base_stat>::value;
51- *(iter++) = var. best_sol_val ().value_or (nan);
52- *(iter++) = var. avg_sol ().value_or (nan);
53- iter[static_cast <std::size_t >(var. type_ ( ))] = 1 .;
115+ *(iter++) = best_sol_val (scip, var ).value_or (nan);
116+ *(iter++) = avg_sol (scip, var ).value_or (nan);
117+ iter[static_cast <std::size_t >(SCIPvarGetType (var ))] = 1 .;
54118 iter += scip::enum_size<scip::var_type>::value;
55119 }
56120
@@ -60,48 +124,94 @@ static auto extract_col_feat(scip::Model const& model) {
60124 return col_feat;
61125}
62126
127+ /* *****************************************
128+ * Column features extraction functions *
129+ ******************************************/
130+
131+ scip::real row_l2_norm (scip::Row* const row) noexcept {
132+ auto const norm = SCIProwGetNorm (row);
133+ return norm > 0 ? norm : 1 .;
134+ }
135+
136+ nonstd::optional<scip::real> left_hand_side (Scip* const scip, scip::Row* const row) noexcept {
137+ auto const lhs_val = SCIProwGetLhs (row);
138+ if (SCIPisInfinity (scip, REALABS (lhs_val))) {
139+ return {};
140+ }
141+ return lhs_val - SCIProwGetConstant (row);
142+ }
143+
144+ nonstd::optional<scip::real> right_hand_side (Scip* const scip, scip::Row* const row) noexcept {
145+ auto const rhs_val = SCIProwGetRhs (row);
146+ if (SCIPisInfinity (scip, REALABS (rhs_val))) {
147+ return {};
148+ }
149+ return rhs_val - SCIProwGetConstant (row);
150+ }
151+
152+ bool is_at_lhs (Scip* const scip, scip::Row* const row) noexcept {
153+ auto const activity = SCIPgetRowLPActivity (scip, row);
154+ auto const lhs_val = SCIProwGetLhs (row);
155+ return SCIPisEQ (scip, activity, lhs_val);
156+ }
157+
158+ bool is_at_rhs (Scip* const scip, scip::Row* const row) noexcept {
159+ auto const activity = SCIPgetRowLPActivity (scip, row);
160+ auto const rhs_val = SCIProwGetRhs (row);
161+ return SCIPisEQ (scip, activity, rhs_val);
162+ }
163+
164+ scip::real obj_cos_sim (Scip* const scip, scip::Row* const row) noexcept {
165+ auto const norm_prod = SCIProwGetNorm (row) * SCIPgetObjNorm (scip);
166+ if (SCIPisPositive (scip, norm_prod)) {
167+ return row->objprod / norm_prod;
168+ }
169+ return 0 .;
170+ }
171+
63172/* *
64173 * Number of inequality rows.
65174 *
66175 * Row are counted once per right hand side and once per left hand side.
67176 */
68- static std::size_t get_n_ineq_rows (scip::Model const & model) {
177+ std::size_t n_ineq_rows (scip::Model const & model) {
178+ auto const scip = model.get_scip_ptr ();
69179 std::size_t count = 0 ;
70180 for (auto row : model.lp_rows ()) {
71- count += static_cast <std::size_t >(row. rhs ( ).has_value ());
72- count += static_cast <std::size_t >(row. lhs ( ).has_value ());
181+ count += static_cast <std::size_t >(left_hand_side (scip, row ).has_value ());
182+ count += static_cast <std::size_t >(right_hand_side (scip, row ).has_value ());
73183 }
74184 return count;
75185}
76186
77- static auto extract_row_feat (scip::Model const & model) {
78- tensor row_feat{{get_n_ineq_rows (model), n_row_feat}, 0 .};
187+ auto extract_row_feat (scip::Model const & model) {
188+ auto constexpr n_row_feat = 5 ;
189+ auto const scip = model.get_scip_ptr ();
190+ tensor row_feat{{n_ineq_rows (model), n_row_feat}, 0 .};
79191
80- value_type const n_lps = static_cast <value_type>(SCIPgetNLPs (model. get_scip_ptr () ));
81- value_type const obj_l2_norm = get_obj_norm (model );
192+ value_type const n_lps = static_cast <value_type>(SCIPgetNLPs (scip ));
193+ value_type const obj_norm = obj_l2_norm (scip );
82194
83- auto extract_row = [n_lps, obj_l2_norm ](auto & iter, auto const row, bool const lhs) {
195+ auto extract_row = [n_lps, obj_norm, scip ](auto & iter, auto const row, bool const lhs) {
84196 value_type const sign = lhs ? -1 . : 1 .;
85- value_type row_l2_norm = static_cast <value_type>(row.l2_norm ());
86- if (row_l2_norm == 0 ) row_l2_norm = 1 .;
87-
197+ value_type row_norm = static_cast <value_type>(row_l2_norm (row));
88198 if (lhs) {
89- *(iter++) = sign * row. lhs ( ).value () / row_l2_norm ;
90- *(iter++) = static_cast <value_type>(row. is_at_lhs ());
199+ *(iter++) = sign * left_hand_side (scip, row ).value () / row_norm ;
200+ *(iter++) = static_cast <value_type>(is_at_lhs (scip, row ));
91201 } else {
92- *(iter++) = sign * row. rhs ( ).value () / row_l2_norm ;
93- *(iter++) = static_cast <value_type>(row. is_at_rhs ());
202+ *(iter++) = sign * right_hand_side (scip, row ).value () / row_norm ;
203+ *(iter++) = static_cast <value_type>(is_at_rhs (scip, row ));
94204 }
95- *(iter++) = static_cast <value_type>(row. age ( )) / (n_lps + cste);
96- *(iter++) = sign * row. obj_cos_sim ();
97- *(iter++) = sign * row. dual_sol ( ) / (row_l2_norm * obj_l2_norm );
205+ *(iter++) = static_cast <value_type>(SCIProwGetAge (row )) / (n_lps + cste);
206+ *(iter++) = sign * obj_cos_sim (scip, row );
207+ *(iter++) = sign * SCIProwGetDualsol (row ) / (row_norm * obj_norm );
98208 };
99209
100210 auto iter_ = row_feat.begin ();
101211 for (auto const row_ : model.lp_rows ()) {
102212 // Rows are counted once per rhs and once per lhs
103- if (row_. lhs ( ).has_value ()) extract_row (iter_, row_, true );
104- if (row_. rhs ( ).has_value ()) extract_row (iter_, row_, false );
213+ if (left_hand_side (scip, row_ ).has_value ()) extract_row (iter_, row_, true );
214+ if (right_hand_side (scip, row_ ).has_value ()) extract_row (iter_, row_, false );
105215 }
106216
107217 // Make sure we iterated over as many element as there are in the tensor
@@ -112,22 +222,27 @@ static auto extract_row_feat(scip::Model const& model) {
112222 return row_feat;
113223}
114224
225+ /* ***************************************
226+ * Edge features extraction functions *
227+ ****************************************/
228+
115229/* *
116230 * Number of non zero element in the constraint matrix.
117231 *
118232 * Row are counted once per right hand side and once per left hand side.
119233 */
120- static auto matrix_nnz (scip::Model const & model) {
234+ auto matrix_nnz (scip::Model const & model) {
235+ auto const scip = model.get_scip_ptr ();
121236 std::size_t nnz = 0 ;
122237 for (auto row : model.lp_rows ()) {
123- auto const row_size = static_cast <std::size_t >(row. n_lp_nonz ( ));
124- if (row. lhs ( ).has_value ()) nnz += row_size;
125- if (row. rhs ( ).has_value ()) nnz += row_size;
238+ auto const row_size = static_cast <std::size_t >(SCIProwGetNLPNonz (row ));
239+ if (left_hand_side (scip, row ).has_value ()) nnz += row_size;
240+ if (right_hand_side (scip, row ).has_value ()) nnz += row_size;
126241 }
127242 return nnz;
128243}
129244
130- static utility::coo_matrix<value_type> extract_edge_feat (scip::Model const & model) {
245+ utility::coo_matrix<value_type> extract_edge_feat (scip::Model const & model) {
131246 auto const scip = model.get_scip_ptr ();
132247
133248 using coo_matrix = utility::coo_matrix<value_type>;
@@ -137,10 +252,10 @@ static utility::coo_matrix<value_type> extract_edge_feat(scip::Model const& mode
137252
138253 std::size_t i = 0 , j = 0 ;
139254 for (auto const row : model.lp_rows ()) {
140- SCIP_COL** const row_cols = SCIProwGetCols (row. value );
141- scip::real const * const row_vals = SCIProwGetVals (row. value );
142- std:: size_t const row_nnz = static_cast <std::size_t >(SCIProwGetNLPNonz (row. value ));
143- if (row. lhs ( ).has_value ()) {
255+ auto const row_cols = SCIProwGetCols (row);
256+ auto const * const row_vals = SCIProwGetVals (row);
257+ auto const row_nnz = static_cast <std::size_t >(SCIProwGetNLPNonz (row));
258+ if (left_hand_side (scip, row ).has_value ()) {
144259 for (std::size_t k = 0 ; k < row_nnz; ++k) {
145260 indices (0 , j + k) = i;
146261 indices (1 , j + k) = static_cast <std::size_t >(SCIPcolGetLPPos (row_cols[k]));
@@ -149,7 +264,7 @@ static utility::coo_matrix<value_type> extract_edge_feat(scip::Model const& mode
149264 j += row_nnz;
150265 i++;
151266 }
152- if (row. rhs ( ).has_value ()) {
267+ if (right_hand_side (scip, row ).has_value ()) {
153268 for (std::size_t k = 0 ; k < row_nnz; ++k) {
154269 indices (0 , j + k) = i;
155270 indices (1 , j + k) = static_cast <std::size_t >(SCIPcolGetLPPos (row_cols[k]));
@@ -160,18 +275,23 @@ static utility::coo_matrix<value_type> extract_edge_feat(scip::Model const& mode
160275 }
161276 }
162277
163- auto const n_rows = get_n_ineq_rows (model);
278+ auto const n_rows = n_ineq_rows (model);
164279 auto const n_cols = static_cast <std::size_t >(SCIPgetNLPCols (scip));
165280 return {values, indices, {n_rows, n_cols}};
166281}
167282
283+ } // namespace
284+
285+ /* ************************************
286+ * Observation extracting function *
287+ *************************************/
288+
168289auto NodeBipartite::obtain_observation (scip::Model& model) -> nonstd::optional<NodeBipartiteObs> {
169290 if (model.get_stage () == SCIP_STAGE_SOLVING) {
170291 return NodeBipartiteObs{
171292 extract_col_feat (model), extract_row_feat (model), extract_edge_feat (model)};
172- } else {
173- return {};
174293 }
294+ return {};
175295}
176296
177297} // namespace observation
0 commit comments