Skip to content

Commit be49daf

Browse files
committed
Adapt NodeBipartite not to use Col/Row views
1 parent 3184ec7 commit be49daf

File tree

1 file changed

+181
-61
lines changed

1 file changed

+181
-61
lines changed

libecole/src/observation/nodebipartite.cpp

Lines changed: 181 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#include <cstddef>
33
#include <limits>
44

5+
#include <scip/scip.h>
6+
#include <scip/struct_lp.h>
57
#include <xtensor/xview.hpp>
68

79
#include "ecole/observation/nodebipartite.hpp"
@@ -11,46 +13,108 @@
1113
namespace ecole {
1214
namespace observation {
1315

16+
namespace {
17+
18+
/*********************
19+
* Common helpers *
20+
*********************/
21+
1422
using tensor = decltype(NodeBipartiteObs::column_features);
1523
using value_type = tensor::value_type;
1624

17-
static value_type constexpr cste = 5.;
18-
static value_type constexpr nan = std::numeric_limits<value_type>::quiet_NaN();
19-
static auto constexpr n_row_feat = 5;
20-
static auto constexpr n_col_feat =
21-
11 + scip::enum_size<scip::var_type>::value + scip::enum_size<scip::base_stat>::value;
25+
value_type constexpr cste = 5.;
26+
value_type constexpr nan = std::numeric_limits<value_type>::quiet_NaN();
2227

23-
static value_type get_obj_norm(scip::Model const& model) {
24-
auto norm = SCIPgetObjNorm(model.get_scip_ptr());
28+
scip::real obj_l2_norm(Scip* const scip) noexcept {
29+
auto const norm = SCIPgetObjNorm(scip);
2530
return norm > 0 ? norm : 1.;
2631
}
2732

28-
static auto extract_col_feat(scip::Model const& model) {
29-
tensor col_feat{{model.lp_columns().size, n_col_feat}, 0.};
33+
/******************************************
34+
* Column features extraction functions *
35+
******************************************/
36+
37+
nonstd::optional<scip::real> upper_bound(Scip* const scip, scip::Col* const col) noexcept {
38+
auto const ub_val = SCIPcolGetUb(col);
39+
if (SCIPisInfinity(scip, REALABS(ub_val))) {
40+
return {};
41+
}
42+
return ub_val;
43+
}
44+
45+
nonstd::optional<scip::real> lower_bound(Scip* const scip, scip::Col* const col) noexcept {
46+
auto const lb_val = SCIPcolGetLb(col);
47+
if (SCIPisInfinity(scip, REALABS(lb_val))) {
48+
return {};
49+
}
50+
return lb_val;
51+
}
52+
53+
bool is_prim_sol_at_lb(Scip* const scip, scip::Col* const col) noexcept {
54+
auto const lb_val = lower_bound(scip, col);
55+
if (lb_val) {
56+
return SCIPisEQ(scip, SCIPcolGetPrimsol(col), lb_val.value());
57+
}
58+
return false;
59+
}
60+
61+
bool is_prim_sol_at_ub(Scip* const scip, scip::Col* const col) noexcept {
62+
auto const ub_val = upper_bound(scip, col);
63+
if (ub_val) {
64+
return SCIPisEQ(scip, SCIPcolGetPrimsol(col), ub_val.value());
65+
}
66+
return false;
67+
}
68+
69+
nonstd::optional<scip::real> best_sol_val(Scip* const scip, scip::Var* const var) noexcept {
70+
auto const sol = SCIPgetBestSol(scip);
71+
if (sol != nullptr) {
72+
return SCIPgetSolVal(scip, sol, var);
73+
}
74+
return {};
75+
}
76+
77+
nonstd::optional<scip::real> avg_sol(Scip* const scip, scip::Var* const var) noexcept {
78+
if (SCIPgetBestSol(scip) != nullptr) {
79+
return SCIPvarGetAvgSol(var);
80+
}
81+
return {};
82+
}
83+
84+
nonstd::optional<scip::real>
85+
feas_frac(Scip* const scip, scip::Var* const var, scip::Col* const col) noexcept {
86+
if (SCIPvarGetType(var) == SCIP_VARTYPE_CONTINUOUS) {
87+
return {};
88+
}
89+
return SCIPfeasFrac(scip, SCIPcolGetPrimsol(col));
90+
}
91+
92+
auto extract_col_feat(scip::Model const& model) {
93+
auto constexpr n_col_feat =
94+
11 + scip::enum_size<scip::var_type>::value + scip::enum_size<scip::base_stat>::value;
95+
auto const scip = model.get_scip_ptr();
96+
tensor col_feat{{model.lp_columns().size(), n_col_feat}, 0.};
3097

31-
value_type const n_lps = static_cast<value_type>(SCIPgetNLPs(model.get_scip_ptr()));
32-
value_type const obj_l2_norm = get_obj_norm(model);
98+
value_type const n_lps = static_cast<value_type>(SCIPgetNLPs(scip));
99+
value_type const obj_norm = obj_l2_norm(scip);
33100

34101
auto iter = col_feat.begin();
35102
for (auto const col : model.lp_columns()) {
36-
auto const var = col.var();
37-
*(iter++) = static_cast<value_type>(col.lb().has_value());
38-
*(iter++) = static_cast<value_type>(col.ub().has_value());
39-
*(iter++) = col.reduced_cost() / obj_l2_norm;
40-
*(iter++) = col.obj() / obj_l2_norm;
41-
*(iter++) = col.prim_sol();
42-
if (var.type_() == SCIP_VARTYPE_CONTINUOUS)
43-
*(iter++) = 0.;
44-
else
45-
*(iter++) = col.prim_sol_frac();
46-
*(iter++) = static_cast<value_type>(col.is_prim_sol_at_lb());
47-
*(iter++) = static_cast<value_type>(col.is_prim_sol_at_ub());
48-
*(iter++) = static_cast<value_type>(col.age()) / (n_lps + cste);
49-
iter[static_cast<std::size_t>(col.basis_status())] = 1.;
103+
auto const var = SCIPcolGetVar(col);
104+
*(iter++) = static_cast<value_type>(lower_bound(scip, col).has_value());
105+
*(iter++) = static_cast<value_type>(upper_bound(scip, col).has_value());
106+
*(iter++) = SCIPgetColRedcost(scip, col) / obj_norm;
107+
*(iter++) = SCIPcolGetObj(col) / obj_norm;
108+
*(iter++) = SCIPcolGetPrimsol(col);
109+
*(iter++) = feas_frac(scip, var, col).value_or(0.);
110+
*(iter++) = static_cast<value_type>(is_prim_sol_at_lb(scip, col));
111+
*(iter++) = static_cast<value_type>(is_prim_sol_at_ub(scip, col));
112+
*(iter++) = static_cast<value_type>(col->age) / (n_lps + cste);
113+
iter[static_cast<std::size_t>(SCIPcolGetBasisStatus(col))] = 1.;
50114
iter += scip::enum_size<scip::base_stat>::value;
51-
*(iter++) = var.best_sol_val().value_or(nan);
52-
*(iter++) = var.avg_sol().value_or(nan);
53-
iter[static_cast<std::size_t>(var.type_())] = 1.;
115+
*(iter++) = best_sol_val(scip, var).value_or(nan);
116+
*(iter++) = avg_sol(scip, var).value_or(nan);
117+
iter[static_cast<std::size_t>(SCIPvarGetType(var))] = 1.;
54118
iter += scip::enum_size<scip::var_type>::value;
55119
}
56120

@@ -60,48 +124,94 @@ static auto extract_col_feat(scip::Model const& model) {
60124
return col_feat;
61125
}
62126

127+
/******************************************
128+
* Column features extraction functions *
129+
******************************************/
130+
131+
scip::real row_l2_norm(scip::Row* const row) noexcept {
132+
auto const norm = SCIProwGetNorm(row);
133+
return norm > 0 ? norm : 1.;
134+
}
135+
136+
nonstd::optional<scip::real> left_hand_side(Scip* const scip, scip::Row* const row) noexcept {
137+
auto const lhs_val = SCIProwGetLhs(row);
138+
if (SCIPisInfinity(scip, REALABS(lhs_val))) {
139+
return {};
140+
}
141+
return lhs_val - SCIProwGetConstant(row);
142+
}
143+
144+
nonstd::optional<scip::real> right_hand_side(Scip* const scip, scip::Row* const row) noexcept {
145+
auto const rhs_val = SCIProwGetRhs(row);
146+
if (SCIPisInfinity(scip, REALABS(rhs_val))) {
147+
return {};
148+
}
149+
return rhs_val - SCIProwGetConstant(row);
150+
}
151+
152+
bool is_at_lhs(Scip* const scip, scip::Row* const row) noexcept {
153+
auto const activity = SCIPgetRowLPActivity(scip, row);
154+
auto const lhs_val = SCIProwGetLhs(row);
155+
return SCIPisEQ(scip, activity, lhs_val);
156+
}
157+
158+
bool is_at_rhs(Scip* const scip, scip::Row* const row) noexcept {
159+
auto const activity = SCIPgetRowLPActivity(scip, row);
160+
auto const rhs_val = SCIProwGetRhs(row);
161+
return SCIPisEQ(scip, activity, rhs_val);
162+
}
163+
164+
scip::real obj_cos_sim(Scip* const scip, scip::Row* const row) noexcept {
165+
auto const norm_prod = SCIProwGetNorm(row) * SCIPgetObjNorm(scip);
166+
if (SCIPisPositive(scip, norm_prod)) {
167+
return row->objprod / norm_prod;
168+
}
169+
return 0.;
170+
}
171+
63172
/**
64173
* Number of inequality rows.
65174
*
66175
* Row are counted once per right hand side and once per left hand side.
67176
*/
68-
static std::size_t get_n_ineq_rows(scip::Model const& model) {
177+
std::size_t n_ineq_rows(scip::Model const& model) {
178+
auto const scip = model.get_scip_ptr();
69179
std::size_t count = 0;
70180
for (auto row : model.lp_rows()) {
71-
count += static_cast<std::size_t>(row.rhs().has_value());
72-
count += static_cast<std::size_t>(row.lhs().has_value());
181+
count += static_cast<std::size_t>(left_hand_side(scip, row).has_value());
182+
count += static_cast<std::size_t>(right_hand_side(scip, row).has_value());
73183
}
74184
return count;
75185
}
76186

77-
static auto extract_row_feat(scip::Model const& model) {
78-
tensor row_feat{{get_n_ineq_rows(model), n_row_feat}, 0.};
187+
auto extract_row_feat(scip::Model const& model) {
188+
auto constexpr n_row_feat = 5;
189+
auto const scip = model.get_scip_ptr();
190+
tensor row_feat{{n_ineq_rows(model), n_row_feat}, 0.};
79191

80-
value_type const n_lps = static_cast<value_type>(SCIPgetNLPs(model.get_scip_ptr()));
81-
value_type const obj_l2_norm = get_obj_norm(model);
192+
value_type const n_lps = static_cast<value_type>(SCIPgetNLPs(scip));
193+
value_type const obj_norm = obj_l2_norm(scip);
82194

83-
auto extract_row = [n_lps, obj_l2_norm](auto& iter, auto const row, bool const lhs) {
195+
auto extract_row = [n_lps, obj_norm, scip](auto& iter, auto const row, bool const lhs) {
84196
value_type const sign = lhs ? -1. : 1.;
85-
value_type row_l2_norm = static_cast<value_type>(row.l2_norm());
86-
if (row_l2_norm == 0) row_l2_norm = 1.;
87-
197+
value_type row_norm = static_cast<value_type>(row_l2_norm(row));
88198
if (lhs) {
89-
*(iter++) = sign * row.lhs().value() / row_l2_norm;
90-
*(iter++) = static_cast<value_type>(row.is_at_lhs());
199+
*(iter++) = sign * left_hand_side(scip, row).value() / row_norm;
200+
*(iter++) = static_cast<value_type>(is_at_lhs(scip, row));
91201
} else {
92-
*(iter++) = sign * row.rhs().value() / row_l2_norm;
93-
*(iter++) = static_cast<value_type>(row.is_at_rhs());
202+
*(iter++) = sign * right_hand_side(scip, row).value() / row_norm;
203+
*(iter++) = static_cast<value_type>(is_at_rhs(scip, row));
94204
}
95-
*(iter++) = static_cast<value_type>(row.age()) / (n_lps + cste);
96-
*(iter++) = sign * row.obj_cos_sim();
97-
*(iter++) = sign * row.dual_sol() / (row_l2_norm * obj_l2_norm);
205+
*(iter++) = static_cast<value_type>(SCIProwGetAge(row)) / (n_lps + cste);
206+
*(iter++) = sign * obj_cos_sim(scip, row);
207+
*(iter++) = sign * SCIProwGetDualsol(row) / (row_norm * obj_norm);
98208
};
99209

100210
auto iter_ = row_feat.begin();
101211
for (auto const row_ : model.lp_rows()) {
102212
// Rows are counted once per rhs and once per lhs
103-
if (row_.lhs().has_value()) extract_row(iter_, row_, true);
104-
if (row_.rhs().has_value()) extract_row(iter_, row_, false);
213+
if (left_hand_side(scip, row_).has_value()) extract_row(iter_, row_, true);
214+
if (right_hand_side(scip, row_).has_value()) extract_row(iter_, row_, false);
105215
}
106216

107217
// Make sure we iterated over as many element as there are in the tensor
@@ -112,22 +222,27 @@ static auto extract_row_feat(scip::Model const& model) {
112222
return row_feat;
113223
}
114224

225+
/****************************************
226+
* Edge features extraction functions *
227+
****************************************/
228+
115229
/**
116230
* Number of non zero element in the constraint matrix.
117231
*
118232
* Row are counted once per right hand side and once per left hand side.
119233
*/
120-
static auto matrix_nnz(scip::Model const& model) {
234+
auto matrix_nnz(scip::Model const& model) {
235+
auto const scip = model.get_scip_ptr();
121236
std::size_t nnz = 0;
122237
for (auto row : model.lp_rows()) {
123-
auto const row_size = static_cast<std::size_t>(row.n_lp_nonz());
124-
if (row.lhs().has_value()) nnz += row_size;
125-
if (row.rhs().has_value()) nnz += row_size;
238+
auto const row_size = static_cast<std::size_t>(SCIProwGetNLPNonz(row));
239+
if (left_hand_side(scip, row).has_value()) nnz += row_size;
240+
if (right_hand_side(scip, row).has_value()) nnz += row_size;
126241
}
127242
return nnz;
128243
}
129244

130-
static utility::coo_matrix<value_type> extract_edge_feat(scip::Model const& model) {
245+
utility::coo_matrix<value_type> extract_edge_feat(scip::Model const& model) {
131246
auto const scip = model.get_scip_ptr();
132247

133248
using coo_matrix = utility::coo_matrix<value_type>;
@@ -137,10 +252,10 @@ static utility::coo_matrix<value_type> extract_edge_feat(scip::Model const& mode
137252

138253
std::size_t i = 0, j = 0;
139254
for (auto const row : model.lp_rows()) {
140-
SCIP_COL** const row_cols = SCIProwGetCols(row.value);
141-
scip::real const* const row_vals = SCIProwGetVals(row.value);
142-
std::size_t const row_nnz = static_cast<std::size_t>(SCIProwGetNLPNonz(row.value));
143-
if (row.lhs().has_value()) {
255+
auto const row_cols = SCIProwGetCols(row);
256+
auto const* const row_vals = SCIProwGetVals(row);
257+
auto const row_nnz = static_cast<std::size_t>(SCIProwGetNLPNonz(row));
258+
if (left_hand_side(scip, row).has_value()) {
144259
for (std::size_t k = 0; k < row_nnz; ++k) {
145260
indices(0, j + k) = i;
146261
indices(1, j + k) = static_cast<std::size_t>(SCIPcolGetLPPos(row_cols[k]));
@@ -149,7 +264,7 @@ static utility::coo_matrix<value_type> extract_edge_feat(scip::Model const& mode
149264
j += row_nnz;
150265
i++;
151266
}
152-
if (row.rhs().has_value()) {
267+
if (right_hand_side(scip, row).has_value()) {
153268
for (std::size_t k = 0; k < row_nnz; ++k) {
154269
indices(0, j + k) = i;
155270
indices(1, j + k) = static_cast<std::size_t>(SCIPcolGetLPPos(row_cols[k]));
@@ -160,18 +275,23 @@ static utility::coo_matrix<value_type> extract_edge_feat(scip::Model const& mode
160275
}
161276
}
162277

163-
auto const n_rows = get_n_ineq_rows(model);
278+
auto const n_rows = n_ineq_rows(model);
164279
auto const n_cols = static_cast<std::size_t>(SCIPgetNLPCols(scip));
165280
return {values, indices, {n_rows, n_cols}};
166281
}
167282

283+
} // namespace
284+
285+
/*************************************
286+
* Observation extracting function *
287+
*************************************/
288+
168289
auto NodeBipartite::obtain_observation(scip::Model& model) -> nonstd::optional<NodeBipartiteObs> {
169290
if (model.get_stage() == SCIP_STAGE_SOLVING) {
170291
return NodeBipartiteObs{
171292
extract_col_feat(model), extract_row_feat(model), extract_edge_feat(model)};
172-
} else {
173-
return {};
174293
}
294+
return {};
175295
}
176296

177297
} // namespace observation

0 commit comments

Comments
 (0)