Skip to content

Commit ea6822b

Browse files
committed
Refactor variable iteration in Khalil2016
1 parent a0a6e7f commit ea6822b

File tree

4 files changed

+60
-26
lines changed

4 files changed

+60
-26
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include <utility>
4+
5+
#include <nonstd/span.hpp>
6+
#include <range/v3/range_fwd.hpp>
7+
8+
/**
9+
* Tell the range library that `nonstd::span` is a view type.
10+
*
11+
* See `Rvalue Ranges and Views in C++20 <https://tristanbrindle.com/posts/rvalue-ranges-and-views>`_
12+
* FIXME no longer needed when switching to C++20 ``std::span``.
13+
* */
14+
namespace ranges {
15+
template <typename T, std::size_t Extent> inline constexpr bool enable_borrowed_range<nonstd::span<T, Extent>> = true;
16+
} // namespace ranges

libecole/src/observation/khalil-2016.cpp

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -551,15 +551,10 @@ void set_dynamic_features(
551551
* The static features have been computed for all LP columns and stored in the order of `LPcolumns`.
552552
* We need to find the one associated with the given variable.
553553
*/
554-
template <typename Tensor>
555-
void set_precomputed_static_features(
556-
Tensor&& out,
557-
SCIP_VAR* const var,
558-
xt::xtensor<value_type, 2> const& static_features) {
559-
560-
auto const col_idx = static_cast<std::ptrdiff_t>(SCIPcolGetIndex(SCIPvarGetCol(var)));
554+
template <typename TensorOut, typename TensorIn>
555+
void set_precomputed_static_features(TensorOut&& var_features, TensorIn const& var_static_features) {
561556
using namespace xt::placeholders;
562-
xt::view(out, xt::range(_, Khalil2016Obs::n_static_features)) = xt::row(static_features, col_idx);
557+
xt::view(var_features, xt::range(_, Khalil2016Obs::n_static_features)) = var_static_features;
563558
}
564559

565560
/******************************
@@ -568,20 +563,19 @@ void set_precomputed_static_features(
568563

569564
auto extract_all_features(scip::Model& model, xt::xtensor<value_type, 2> const& static_features) {
570565
xt::xtensor<value_type, 2> observation{
571-
{model.pseudo_branch_cands().size(), Khalil2016Obs::n_features},
566+
{model.variables().size(), Khalil2016Obs::n_features},
572567
std::nan(""),
573568
};
574569

575570
auto* const scip = model.get_scip_ptr();
576571
auto const active_rows_weights = stats_for_active_constraint_coefficients_weights(model);
577572

578-
auto const pseudo_branch_cands = model.pseudo_branch_cands();
579-
auto const n_pseudo_branch_cands = pseudo_branch_cands.size();
580-
for (std::size_t var_idx = 0; var_idx < n_pseudo_branch_cands; ++var_idx) {
581-
auto* const var = pseudo_branch_cands[var_idx];
582-
auto features = xt::row(observation, static_cast<std::ptrdiff_t>(var_idx));
583-
set_precomputed_static_features(features, var, static_features);
584-
set_dynamic_features(features, scip, var, active_rows_weights);
573+
for (auto* var : model.pseudo_branch_cands()) {
574+
auto const var_idx = SCIPvarGetProbindex(var);
575+
auto var_features = xt::row(observation, var_idx);
576+
auto var_static_features = xt::row(static_features, var_idx);
577+
set_precomputed_static_features(var_features, var_static_features);
578+
set_dynamic_features(var_features, scip, var, active_rows_weights);
585579
}
586580

587581
return observation;

libecole/tests/src/observation/test-khalil-2016.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
#include <catch2/catch.hpp>
2+
#include <range/v3/view/enumerate.hpp>
3+
#include <range/v3/view/transform.hpp>
4+
#include <scip/scip.h>
5+
#include <xtensor/xindex_view.hpp>
26
#include <xtensor/xmath.hpp>
7+
#include <xtensor/xtensor.hpp>
38
#include <xtensor/xview.hpp>
49

510
#include "ecole/observation/khalil-2016.hpp"
11+
#include "ecole/tweak/range.hpp"
612

713
#include "conftest.hpp"
814
#include "observation/unit-tests.hpp"
915

16+
namespace views = ranges::views;
17+
1018
using namespace ecole;
1119

1220
TEST_CASE("Khalil2016 unit tests", "[unit][obs]") {
@@ -19,6 +27,16 @@ auto in_interval(Tensor const& tensor, T const& lower, T const& upper) {
1927
return (lower <= tensor) && (tensor <= upper);
2028
}
2129

30+
/** Get the features of the pseudo candidate only. */
31+
template <typename Tensor, typename Range>
32+
auto obs_pseudo_cands(Tensor const& obs_features, Range const& pseudo_cands_idx) -> Tensor {
33+
auto filtered_features = Tensor::from_shape({pseudo_cands_idx.size(), obs_features.shape()[1]});
34+
for (auto const [idx, var_idx] : views::enumerate(pseudo_cands_idx)) {
35+
xt::row(filtered_features, static_cast<std::ptrdiff_t>(idx)) = xt::row(obs_features, var_idx);
36+
}
37+
return filtered_features;
38+
}
39+
2240
TEST_CASE("Khalil2016 return correct observation", "[obs]") {
2341
using Features = observation::Khalil2016Obs::Features;
2442

@@ -32,19 +50,23 @@ TEST_CASE("Khalil2016 return correct observation", "[obs]") {
3250

3351
SECTION("Observation features has correct shape") {
3452
auto const& obs = optional_obs.value();
35-
REQUIRE(obs.features.shape(0) == model.pseudo_branch_cands().size());
53+
REQUIRE(obs.features.shape(0) == model.variables().size());
3654
REQUIRE(obs.features.shape(1) == observation::Khalil2016Obs::n_features);
3755
}
3856

39-
SECTION("No features are NaN or infinite") {
40-
auto const& obs = optional_obs.value();
41-
REQUIRE_FALSE(xt::any(xt::isnan(obs.features)));
42-
REQUIRE_FALSE(xt::any(xt::isinf(obs.features)));
43-
}
44-
4557
SECTION("Observation has correct values") {
4658
auto const& obs = optional_obs.value();
47-
auto col = [&obs](auto feat) { return xt::col(obs.features, static_cast<std::ptrdiff_t>(feat)); };
59+
auto obs_pseudo =
60+
obs_pseudo_cands(obs.features, views::transform(model.pseudo_branch_cands(), SCIPvarGetProbindex));
61+
auto col = [&obs_pseudo](auto feat) { return xt::col(obs_pseudo, static_cast<std::ptrdiff_t>(feat)); };
62+
63+
SECTION("No pseudo_candidate features are NaN or infinite") {
64+
for (auto* var : model.pseudo_branch_cands()) {
65+
auto const var_idx = SCIPvarGetProbindex(var);
66+
REQUIRE_FALSE(xt::any(xt::isnan(xt::row(obs.features, var_idx))));
67+
REQUIRE_FALSE(xt::any(xt::isinf(xt::row(obs.features, var_idx))));
68+
}
69+
}
4870

4971
SECTION("Objective function coefficients") {
5072
REQUIRE(xt::all(col(Features::obj_coef_pos_part) >= 0));

python/src/ecole/core/observation.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,10 @@ void bind_submodule(py::module_ const& m) {
263263
auto_class<Khalil2016Obs>(m, "Khalil2016Obs", R"(
264264
Branching candidates features from Khalil et al. (2016).
265265
266-
The observation is a matrix where rows represent pseudo branching candidates and columns
267-
represent features related to these variables.
266+
The observation is a matrix where rows represent all variables and columns represent features related
267+
to these variables.
268+
Only rows representing pseudo branching candidate contain meaningful observation, other rows are filled with
269+
``NaN``.
268270
See [Khalil2016]_ for a complete reference on this observation function.
269271
270272
The first :py:attr:`Khalil2016Obs.n_static_features` are static (they do not change through the solving

0 commit comments

Comments
 (0)