Skip to content

Commit 0e371d8

Browse files
dchetelatAntoinePrv
authored andcommitted
Rewrote NodeBipartite in terms of variables rather than columns
1 parent beea5ed commit 0e371d8

File tree

5 files changed

+91
-89
lines changed

5 files changed

+91
-89
lines changed

libecole/include/ecole/observation/nodebipartite.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@ namespace ecole::observation {
1212
struct NodeBipartiteObs {
1313
using value_type = double;
1414

15-
static inline std::size_t constexpr n_static_column_features = 5;
16-
static inline std::size_t constexpr n_dynamic_column_features = 14;
17-
static inline std::size_t constexpr n_column_features = n_static_column_features + n_dynamic_column_features;
18-
enum struct ColumnFeatures : std::size_t {
15+
static inline std::size_t constexpr n_static_variable_features = 5;
16+
static inline std::size_t constexpr n_dynamic_variable_features = 14;
17+
static inline std::size_t constexpr n_variable_features = n_static_variable_features + n_dynamic_variable_features;
18+
enum struct VariableFeatures : std::size_t {
1919
/** Static features */
2020
objective = 0,
2121
is_type_binary, // One hot encoded
@@ -54,7 +54,7 @@ struct NodeBipartiteObs {
5454
scaled_age,
5555
};
5656

57-
xt::xtensor<value_type, 2> column_features;
57+
xt::xtensor<value_type, 2> variable_features;
5858
xt::xtensor<value_type, 2> row_features;
5959
utility::coo_matrix<value_type> edge_features;
6060
};

libecole/src/observation/nodebipartite.cpp

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ namespace {
2020
* Common helpers *
2121
*********************/
2222

23-
using xmatrix = decltype(NodeBipartiteObs::column_features);
23+
using xmatrix = decltype(NodeBipartiteObs::variable_features);
2424
using value_type = xmatrix::value_type;
2525

26-
using ColumnFeatures = NodeBipartiteObs::ColumnFeatures;
26+
using VariableFeatures = NodeBipartiteObs::VariableFeatures;
2727
using RowFeatures = NodeBipartiteObs::RowFeatures;
2828

2929
value_type constexpr cste = 5.;
@@ -34,9 +34,9 @@ SCIP_Real obj_l2_norm(SCIP* const scip) noexcept {
3434
return norm > 0 ? norm : 1.;
3535
}
3636

37-
/******************************************
38-
* Column features extraction functions *
39-
******************************************/
37+
/*******************************************
38+
* Variable features extraction functions *
39+
*******************************************/
4040

4141
std::optional<SCIP_Real> upper_bound(SCIP* const scip, SCIP_COL* const col) noexcept {
4242
auto const ub_val = SCIPcolGetUb(col);
@@ -85,11 +85,11 @@ std::optional<SCIP_Real> avg_sol(SCIP* const scip, SCIP_VAR* const var) noexcept
8585
return {};
8686
}
8787

88-
std::optional<SCIP_Real> feas_frac(SCIP* const scip, SCIP_VAR* const var, SCIP_COL* const col) noexcept {
88+
std::optional<SCIP_Real> feas_frac(SCIP* const scip, SCIP_VAR* const var) noexcept {
8989
if (SCIPvarGetType(var) == SCIP_VARTYPE_CONTINUOUS) {
9090
return {};
9191
}
92-
return SCIPfeasFrac(scip, SCIPcolGetPrimsol(col));
92+
return SCIPfeasFrac(scip, SCIPvarGetLPSol(var));
9393
}
9494

9595
/** Convert an enum to its underlying index. */
@@ -98,89 +98,89 @@ template <typename E> constexpr auto idx(E e) {
9898
}
9999

100100
template <typename Features>
101-
void set_static_features_for_col(Features&& out, SCIP_VAR* const var, SCIP_COL* const col, value_type obj_norm) {
102-
out[idx(ColumnFeatures::objective)] = SCIPcolGetObj(col) / obj_norm;
101+
void set_static_features_for_var(Features&& out, SCIP_VAR* const var, value_type obj_norm) {
102+
out[idx(VariableFeatures::objective)] = SCIPvarGetObj(var) / obj_norm;
103103
// On-hot enconding of varaible type
104-
out[idx(ColumnFeatures::is_type_binary)] = 0.;
105-
out[idx(ColumnFeatures::is_type_integer)] = 0.;
106-
out[idx(ColumnFeatures::is_type_implicit_integer)] = 0.;
107-
out[idx(ColumnFeatures::is_type_continuous)] = 0.;
104+
out[idx(VariableFeatures::is_type_binary)] = 0.;
105+
out[idx(VariableFeatures::is_type_integer)] = 0.;
106+
out[idx(VariableFeatures::is_type_implicit_integer)] = 0.;
107+
out[idx(VariableFeatures::is_type_continuous)] = 0.;
108108
switch (SCIPvarGetType(var)) {
109109
case SCIP_VARTYPE_BINARY:
110-
out[idx(ColumnFeatures::is_type_binary)] = 1.;
110+
out[idx(VariableFeatures::is_type_binary)] = 1.;
111111
break;
112112
case SCIP_VARTYPE_INTEGER:
113-
out[idx(ColumnFeatures::is_type_integer)] = 1.;
113+
out[idx(VariableFeatures::is_type_integer)] = 1.;
114114
break;
115115
case SCIP_VARTYPE_IMPLINT:
116-
out[idx(ColumnFeatures::is_type_implicit_integer)] = 1.;
116+
out[idx(VariableFeatures::is_type_implicit_integer)] = 1.;
117117
break;
118118
case SCIP_VARTYPE_CONTINUOUS:
119-
out[idx(ColumnFeatures::is_type_continuous)] = 1.;
119+
out[idx(VariableFeatures::is_type_continuous)] = 1.;
120120
break;
121121
default:
122122
assert(false); // All enum cases must be handled
123123
}
124124
}
125125

126126
template <typename Features>
127-
void set_dynamic_features_for_col(
127+
void set_dynamic_features_for_var(
128128
Features&& out,
129129
SCIP* const scip,
130130
SCIP_VAR* const var,
131131
SCIP_COL* const col,
132132
value_type obj_norm,
133133
value_type n_lps) {
134-
out[idx(ColumnFeatures::has_lower_bound)] = static_cast<value_type>(lower_bound(scip, col).has_value());
135-
out[idx(ColumnFeatures::has_upper_bound)] = static_cast<value_type>(upper_bound(scip, col).has_value());
136-
out[idx(ColumnFeatures::normed_reduced_cost)] = SCIPgetColRedcost(scip, col) / obj_norm;
137-
out[idx(ColumnFeatures::solution_value)] = SCIPcolGetPrimsol(col);
138-
out[idx(ColumnFeatures::solution_frac)] = feas_frac(scip, var, col).value_or(0.);
139-
out[idx(ColumnFeatures::is_solution_at_lower_bound)] = static_cast<value_type>(is_prim_sol_at_lb(scip, col));
140-
out[idx(ColumnFeatures::is_solution_at_upper_bound)] = static_cast<value_type>(is_prim_sol_at_ub(scip, col));
141-
out[idx(ColumnFeatures::scaled_age)] = static_cast<value_type>(SCIPcolGetAge(col)) / (n_lps + cste);
142-
out[idx(ColumnFeatures::incumbent_value)] = best_sol_val(scip, var).value_or(nan);
143-
out[idx(ColumnFeatures::average_incumbent_value)] = avg_sol(scip, var).value_or(nan);
134+
out[idx(VariableFeatures::has_lower_bound)] = static_cast<value_type>(lower_bound(scip, col).has_value());
135+
out[idx(VariableFeatures::has_upper_bound)] = static_cast<value_type>(upper_bound(scip, col).has_value());
136+
out[idx(VariableFeatures::normed_reduced_cost)] = SCIPgetVarRedcost(scip, var) / obj_norm;
137+
out[idx(VariableFeatures::solution_value)] = SCIPvarGetLPSol(var);
138+
out[idx(VariableFeatures::solution_frac)] = feas_frac(scip, var).value_or(0.);
139+
out[idx(VariableFeatures::is_solution_at_lower_bound)] = static_cast<value_type>(is_prim_sol_at_lb(scip, col));
140+
out[idx(VariableFeatures::is_solution_at_upper_bound)] = static_cast<value_type>(is_prim_sol_at_ub(scip, col));
141+
out[idx(VariableFeatures::scaled_age)] = static_cast<value_type>(SCIPcolGetAge(col)) / (n_lps + cste);
142+
out[idx(VariableFeatures::incumbent_value)] = best_sol_val(scip, var).value_or(nan);
143+
out[idx(VariableFeatures::average_incumbent_value)] = avg_sol(scip, var).value_or(nan);
144144
// On-hot encoding
145-
out[idx(ColumnFeatures::is_basis_lower)] = 0.;
146-
out[idx(ColumnFeatures::is_basis_basic)] = 0.;
147-
out[idx(ColumnFeatures::is_basis_upper)] = 0.;
148-
out[idx(ColumnFeatures::is_basis_zero)] = 0.;
145+
out[idx(VariableFeatures::is_basis_lower)] = 0.;
146+
out[idx(VariableFeatures::is_basis_basic)] = 0.;
147+
out[idx(VariableFeatures::is_basis_upper)] = 0.;
148+
out[idx(VariableFeatures::is_basis_zero)] = 0.;
149149
switch (SCIPcolGetBasisStatus(col)) {
150150
case SCIP_BASESTAT_LOWER:
151-
out[idx(ColumnFeatures::is_basis_lower)] = 1.;
151+
out[idx(VariableFeatures::is_basis_lower)] = 1.;
152152
break;
153153
case SCIP_BASESTAT_BASIC:
154-
out[idx(ColumnFeatures::is_basis_basic)] = 1.;
154+
out[idx(VariableFeatures::is_basis_basic)] = 1.;
155155
break;
156156
case SCIP_BASESTAT_UPPER:
157-
out[idx(ColumnFeatures::is_basis_upper)] = 1.;
157+
out[idx(VariableFeatures::is_basis_upper)] = 1.;
158158
break;
159159
case SCIP_BASESTAT_ZERO:
160-
out[idx(ColumnFeatures::is_basis_zero)] = 1.;
160+
out[idx(VariableFeatures::is_basis_zero)] = 1.;
161161
break;
162162
default:
163163
assert(false); // All enum cases must be handled
164164
}
165165
}
166166

167-
void set_features_for_all_cols(xmatrix& out, scip::Model& model, bool const update_static) {
167+
void set_features_for_all_vars(xmatrix& out, scip::Model& model, bool const update_static) {
168168
auto* const scip = model.get_scip_ptr();
169169

170170
// Contant reused in every iterations
171171
auto const n_lps = static_cast<value_type>(SCIPgetNLPs(scip));
172172
auto const obj_norm = obj_l2_norm(scip);
173173

174-
auto const columns = model.lp_columns();
175-
auto const n_columns = columns.size();
176-
for (std::size_t col_idx = 0; col_idx < n_columns; ++col_idx) {
177-
auto* const col = columns[col_idx];
178-
auto* const var = SCIPcolGetVar(col);
179-
auto features = xt::row(out, static_cast<std::ptrdiff_t>(col_idx));
174+
auto const variables = model.variables();
175+
auto const n_vars = variables.size();
176+
for (std::size_t var_idx = 0; var_idx < n_vars; ++var_idx) {
177+
auto* const var = variables[var_idx];
178+
auto* const col = SCIPvarGetCol(var);
179+
auto features = xt::row(out, static_cast<std::ptrdiff_t>(var_idx));
180180
if (update_static) {
181-
set_static_features_for_col(features, var, col, obj_norm);
181+
set_static_features_for_var(features, var, obj_norm);
182182
}
183-
set_dynamic_features_for_col(features, scip, var, col, obj_norm, n_lps);
183+
set_dynamic_features_for_var(features, scip, var, col, obj_norm, n_lps);
184184
}
185185
}
186186

@@ -326,7 +326,7 @@ utility::coo_matrix<value_type> extract_edge_features(scip::Model& model) {
326326
if (scip::get_unshifted_lhs(scip, row).has_value()) {
327327
for (std::size_t k = 0; k < row_nnz; ++k) {
328328
indices(0, j + k) = i;
329-
indices(1, j + k) = static_cast<std::size_t>(SCIPcolGetLPPos(row_cols[k]));
329+
indices(1, j + k) = static_cast<std::size_t>(SCIPcolGetVarProbindex(row_cols[k]));
330330
values[j + k] = -row_vals[k];
331331
}
332332
j += row_nnz;
@@ -335,7 +335,7 @@ utility::coo_matrix<value_type> extract_edge_features(scip::Model& model) {
335335
if (scip::get_unshifted_rhs(scip, row).has_value()) {
336336
for (std::size_t k = 0; k < row_nnz; ++k) {
337337
indices(0, j + k) = i;
338-
indices(1, j + k) = static_cast<std::size_t>(SCIPcolGetLPPos(row_cols[k]));
338+
indices(1, j + k) = static_cast<std::size_t>(SCIPcolGetVarProbindex(row_cols[k]));
339339
values[j + k] = row_vals[k];
340340
}
341341
j += row_nnz;
@@ -344,8 +344,9 @@ utility::coo_matrix<value_type> extract_edge_features(scip::Model& model) {
344344
}
345345

346346
auto const n_rows = n_ineq_rows(model);
347-
auto const n_cols = static_cast<std::size_t>(SCIPgetNLPCols(scip));
348-
return {values, indices, {n_rows, n_cols}};
347+
// Change this here for variables
348+
auto const n_vars = static_cast<std::size_t>(SCIPgetNVars(scip));
349+
return {values, indices, {n_rows, n_vars}};
349350
}
350351

351352
auto is_on_root_node(scip::Model& model) -> bool {
@@ -355,17 +356,18 @@ auto is_on_root_node(scip::Model& model) -> bool {
355356

356357
auto extract_observation_fully(scip::Model& model) -> NodeBipartiteObs {
357358
auto obs = NodeBipartiteObs{
358-
xmatrix::from_shape({model.lp_columns().size(), NodeBipartiteObs::n_column_features}),
359+
// Change this here for variables
360+
xmatrix::from_shape({model.variables().size(), NodeBipartiteObs::n_variable_features}),
359361
xmatrix::from_shape({n_ineq_rows(model), NodeBipartiteObs::n_row_features}),
360362
extract_edge_features(model),
361363
};
362-
set_features_for_all_cols(obs.column_features, model, true);
364+
set_features_for_all_vars(obs.variable_features, model, true);
363365
set_features_for_all_rows(obs.row_features, model, true);
364366
return obs;
365367
}
366368

367369
auto extract_observation_from_cache(scip::Model& model, NodeBipartiteObs obs) -> NodeBipartiteObs {
368-
set_features_for_all_cols(obs.column_features, model, false);
370+
set_features_for_all_vars(obs.variable_features, model, false);
369371
set_features_for_all_rows(obs.row_features, model, false);
370372
return obs;
371373
}

libecole/tests/src/observation/test-nodebipartite.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,30 +30,30 @@ TEST_CASE("NodeBipartite return correct observation", "[obs]") {
3030

3131
SECTION("Observation features are not empty") {
3232
auto const& obs = optional_obs.value();
33-
REQUIRE(obs.column_features.size() > 0);
33+
REQUIRE(obs.variable_features.size() > 0);
3434
REQUIRE(obs.row_features.size() > 0);
3535
REQUIRE(obs.edge_features.nnz() > 0);
3636
}
3737

3838
SECTION("Observation features have matching shape") {
3939
auto const& obs = optional_obs.value();
4040
REQUIRE(obs.row_features.shape()[0] == obs.edge_features.shape[0]);
41-
REQUIRE(obs.column_features.shape()[0] == obs.edge_features.shape[1]);
41+
REQUIRE(obs.variable_features.shape()[0] == obs.edge_features.shape[1]);
4242
REQUIRE(obs.edge_features.indices.shape()[0] == 2);
4343
REQUIRE(obs.edge_features.indices.shape()[1] == obs.edge_features.nnz());
4444
}
4545

46-
SECTION("Columns features are not all nan") {
47-
auto const& col_feat = optional_obs.value().column_features;
48-
for (std::size_t i = 0; i < col_feat.shape()[1]; ++i) {
49-
REQUIRE_FALSE(xt::all(xt::isnan(xt::col(col_feat, static_cast<std::ptrdiff_t>(i)))));
46+
SECTION("Variable features are not all nan") {
47+
auto const& var_feat = optional_obs.value().variable_features;
48+
for (std::size_t i = 0; i < var_feat.shape()[1]; ++i) {
49+
REQUIRE_FALSE(xt::all(xt::isnan(xt::col(var_feat, static_cast<std::ptrdiff_t>(i)))));
5050
}
5151
}
5252

5353
SECTION("Row features are not all nan") {
5454
auto const& row_feat = optional_obs.value().row_features;
5555
for (std::size_t i = 0; i < row_feat.shape()[1]; ++i) {
56-
REQUIRE_FALSE(xt::all(xt::isnan(xt::row(row_feat, static_cast<std::ptrdiff_t>(i)))));
56+
REQUIRE_FALSE(xt::all(xt::isnan(xt::col(row_feat, static_cast<std::ptrdiff_t>(i)))));
5757
}
5858
}
5959
}

python/src/ecole/core/observation.cpp

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,10 @@ void bind_submodule(py::module_ const& m) {
9393
Each edge is associated with the coefficient of the variable in the constraint.
9494
)")
9595
.def_auto_copy()
96-
.def_auto_pickle(std::array{"column_features", "row_features", "edge_features"})
96+
.def_auto_pickle(std::array{"variable_features", "row_features", "edge_features"})
9797
.def_readwrite_xtensor(
98-
"column_features",
99-
&NodeBipartiteObs::column_features,
98+
"variable_features",
99+
&NodeBipartiteObs::variable_features,
100100
"A matrix where each row is represents a variable, and each column a feature of the variables.")
101101
.def_readwrite_xtensor(
102102
"row_features",
@@ -108,26 +108,26 @@ void bind_submodule(py::module_ const& m) {
108108
"The constraint matrix of the optimization problem, with rows for contraints and "
109109
"columns for variables.");
110110

111-
py::enum_<NodeBipartiteObs::ColumnFeatures>(node_bipartite_obs, "ColumnFeatures")
112-
.value("objective", NodeBipartiteObs::ColumnFeatures::objective)
113-
.value("is_type_binary", NodeBipartiteObs::ColumnFeatures::is_type_binary)
114-
.value("is_type_integer", NodeBipartiteObs::ColumnFeatures::is_type_integer)
115-
.value("is_type_implicit_integer", NodeBipartiteObs::ColumnFeatures::is_type_implicit_integer)
116-
.value("is_type_continuous", NodeBipartiteObs::ColumnFeatures::is_type_continuous)
117-
.value("has_lower_bound", NodeBipartiteObs::ColumnFeatures::has_lower_bound)
118-
.value("has_upper_bound", NodeBipartiteObs::ColumnFeatures::has_upper_bound)
119-
.value("normed_reduced_cost", NodeBipartiteObs::ColumnFeatures::normed_reduced_cost)
120-
.value("solution_value", NodeBipartiteObs::ColumnFeatures::solution_value)
121-
.value("solution_frac", NodeBipartiteObs::ColumnFeatures::solution_frac)
122-
.value("is_solution_at_lower_bound", NodeBipartiteObs::ColumnFeatures::is_solution_at_lower_bound)
123-
.value("is_solution_at_upper_bound", NodeBipartiteObs::ColumnFeatures::is_solution_at_upper_bound)
124-
.value("scaled_age", NodeBipartiteObs::ColumnFeatures::scaled_age)
125-
.value("incumbent_value", NodeBipartiteObs::ColumnFeatures::incumbent_value)
126-
.value("average_incumbent_value", NodeBipartiteObs::ColumnFeatures::average_incumbent_value)
127-
.value("is_basis_lower", NodeBipartiteObs::ColumnFeatures::is_basis_lower)
128-
.value("is_basis_basic", NodeBipartiteObs::ColumnFeatures::is_basis_basic)
129-
.value("is_basis_upper", NodeBipartiteObs::ColumnFeatures::is_basis_upper)
130-
.value("is_basis_zero", NodeBipartiteObs::ColumnFeatures ::is_basis_zero);
111+
py::enum_<NodeBipartiteObs::VariableFeatures>(node_bipartite_obs, "VariableFeatures")
112+
.value("objective", NodeBipartiteObs::VariableFeatures::objective)
113+
.value("is_type_binary", NodeBipartiteObs::VariableFeatures::is_type_binary)
114+
.value("is_type_integer", NodeBipartiteObs::VariableFeatures::is_type_integer)
115+
.value("is_type_implicit_integer", NodeBipartiteObs::VariableFeatures::is_type_implicit_integer)
116+
.value("is_type_continuous", NodeBipartiteObs::VariableFeatures::is_type_continuous)
117+
.value("has_lower_bound", NodeBipartiteObs::VariableFeatures::has_lower_bound)
118+
.value("has_upper_bound", NodeBipartiteObs::VariableFeatures::has_upper_bound)
119+
.value("normed_reduced_cost", NodeBipartiteObs::VariableFeatures::normed_reduced_cost)
120+
.value("solution_value", NodeBipartiteObs::VariableFeatures::solution_value)
121+
.value("solution_frac", NodeBipartiteObs::VariableFeatures::solution_frac)
122+
.value("is_solution_at_lower_bound", NodeBipartiteObs::VariableFeatures::is_solution_at_lower_bound)
123+
.value("is_solution_at_upper_bound", NodeBipartiteObs::VariableFeatures::is_solution_at_upper_bound)
124+
.value("scaled_age", NodeBipartiteObs::VariableFeatures::scaled_age)
125+
.value("incumbent_value", NodeBipartiteObs::VariableFeatures::incumbent_value)
126+
.value("average_incumbent_value", NodeBipartiteObs::VariableFeatures::average_incumbent_value)
127+
.value("is_basis_lower", NodeBipartiteObs::VariableFeatures::is_basis_lower)
128+
.value("is_basis_basic", NodeBipartiteObs::VariableFeatures::is_basis_basic)
129+
.value("is_basis_upper", NodeBipartiteObs::VariableFeatures::is_basis_upper)
130+
.value("is_basis_zero", NodeBipartiteObs::VariableFeatures ::is_basis_zero);
131131

132132
py::enum_<NodeBipartiteObs::RowFeatures>(node_bipartite_obs, "RowFeatures")
133133
.value("bias", NodeBipartiteObs::RowFeatures::bias)

python/tests/test_observation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@ def test_NodeBipartite_observation(model):
9191
"""Observation of NodeBipartite is a type with array attributes."""
9292
obs = make_obs(ecole.observation.NodeBipartite(), model)
9393
assert isinstance(obs, ecole.observation.NodeBipartiteObs)
94-
assert_array(obs.column_features, ndim=2)
94+
assert_array(obs.variable_features, ndim=2)
9595
assert_array(obs.row_features, ndim=2)
9696
assert_array(obs.edge_features.values)
9797
assert_array(obs.edge_features.indices, ndim=2, dtype=np.uint64)
9898

9999
# Check that there are enums describing feeatures
100-
assert len(obs.ColumnFeatures.__members__) == obs.column_features.shape[1]
100+
assert len(obs.VariableFeatures.__members__) == obs.variable_features.shape[1]
101101
assert len(obs.RowFeatures.__members__) == obs.row_features.shape[1]
102102

103103

0 commit comments

Comments
 (0)