Skip to content

Commit 61f1239

Browse files
committed
Refactor variable iteration in Khalil2016
1 parent fb250eb commit 61f1239

File tree

4 files changed

+60
-27
lines changed

4 files changed

+60
-27
lines changed
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#pragma once
2+
3+
#include <utility>
4+
5+
#include <nonstd/span.hpp>
6+
#include <range/v3/range_fwd.hpp>
7+
8+
/**
9+
* Tell the range library that `nonstd::span` is a view type.
10+
*
11+
* See `Rvalue Ranges and Views in C++20 <https://tristanbrindle.com/posts/rvalue-ranges-and-views>`_
12+
* FIXME no longer needed when switching to C++20 ``std::span``.
13+
* */
14+
namespace ranges {
15+
template <typename T, std::size_t Extent> inline constexpr bool enable_borrowed_range<nonstd::span<T, Extent>> = true;
16+
} // namespace ranges

libecole/src/observation/khalil-2016.cpp

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -551,15 +551,10 @@ void set_dynamic_features(
551551
* The static features have been computed for all LP columns and stored in the order of `LPcolumns`.
552552
* We need to find the one associated with the given variable.
553553
*/
554-
template <typename Tensor>
555-
void set_precomputed_static_features(
556-
Tensor&& out,
557-
std::size_t var_idx,
558-
xt::xtensor<value_type, 2> const& static_features) {
559-
554+
template <typename TensorOut, typename TensorIn>
555+
void set_precomputed_static_features(TensorOut&& var_features, TensorIn const& var_static_features) {
560556
using namespace xt::placeholders;
561-
xt::view(out, xt::range(_, Khalil2016Obs::n_static_features)) =
562-
xt::row(static_features, static_cast<std::ptrdiff_t>(var_idx));
557+
xt::view(var_features, xt::range(_, Khalil2016Obs::n_static_features)) = var_static_features;
563558
}
564559

565560
/******************************
@@ -568,18 +563,17 @@ void set_precomputed_static_features(
568563

569564
auto extract_all_features(scip::Model& model, bool pseudo, xt::xtensor<value_type, 2> const& static_features) {
570565
auto const branch_cands = pseudo ? model.pseudo_branch_cands() : model.lp_branch_cands();
571-
auto const n_branch_cands = branch_cands.size();
572-
573-
auto observation = xt::xtensor<value_type, 2>{{n_branch_cands, Khalil2016Obs::n_features}, std::nan("")};
566+
auto observation = xt::xtensor<value_type, 2>{{model.variables().size(), Khalil2016Obs::n_features}, std::nan("")};
574567

575568
auto* const scip = model.get_scip_ptr();
576569
auto const active_rows_weights = stats_for_active_constraint_coefficients_weights(model);
577570

578-
for (std::size_t var_idx = 0; var_idx < n_branch_cands; ++var_idx) {
579-
auto* const var = branch_cands[var_idx];
580-
auto features = xt::row(observation, static_cast<std::ptrdiff_t>(var_idx));
581-
set_precomputed_static_features(features, var_idx, static_features);
582-
set_dynamic_features(features, scip, var, active_rows_weights);
571+
for (auto* var : branch_cands) {
572+
auto const var_idx = SCIPvarGetProbindex(var);
573+
auto var_features = xt::row(observation, var_idx);
574+
auto var_static_features = xt::row(static_features, var_idx);
575+
set_precomputed_static_features(var_features, var_static_features);
576+
set_dynamic_features(var_features, scip, var, active_rows_weights);
583577
}
584578

585579
return observation;

libecole/tests/src/observation/test-khalil-2016.cpp

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
#include <catch2/catch.hpp>
2+
#include <range/v3/view/enumerate.hpp>
3+
#include <range/v3/view/transform.hpp>
4+
#include <scip/scip.h>
5+
#include <xtensor/xindex_view.hpp>
26
#include <xtensor/xmath.hpp>
7+
#include <xtensor/xtensor.hpp>
38
#include <xtensor/xview.hpp>
49

510
#include "ecole/observation/khalil-2016.hpp"
11+
#include "ecole/tweak/range.hpp"
612

713
#include "conftest.hpp"
814
#include "observation/unit-tests.hpp"
915

16+
namespace views = ranges::views;
17+
1018
using namespace ecole;
1119

1220
TEST_CASE("Khalil2016 unit tests", "[unit][obs]") {
@@ -20,6 +28,16 @@ auto in_interval(Tensor const& tensor, T const& lower, T const& upper) {
2028
return (lower <= tensor) && (tensor <= upper);
2129
}
2230

31+
/** Get the features of the pseudo candidate only. */
32+
template <typename Tensor, typename Range>
33+
auto obs_pseudo_cands(Tensor const& obs_features, Range const& pseudo_cands_idx) -> Tensor {
34+
auto filtered_features = Tensor::from_shape({pseudo_cands_idx.size(), obs_features.shape()[1]});
35+
for (auto const [idx, var_idx] : views::enumerate(pseudo_cands_idx)) {
36+
xt::row(filtered_features, static_cast<std::ptrdiff_t>(idx)) = xt::row(obs_features, var_idx);
37+
}
38+
return filtered_features;
39+
}
40+
2341
TEST_CASE("Khalil2016 return correct observation", "[obs]") {
2442
using Features = observation::Khalil2016Obs::Features;
2543

@@ -34,20 +52,23 @@ TEST_CASE("Khalil2016 return correct observation", "[obs]") {
3452

3553
SECTION("Observation features has correct shape") {
3654
auto const& obs = optional_obs.value();
37-
auto const branch_cands = pseudo ? model.pseudo_branch_cands() : model.lp_branch_cands();
38-
REQUIRE(obs.features.shape(0) == branch_cands.size());
55+
REQUIRE(obs.features.shape(0) == model.variables().size());
3956
REQUIRE(obs.features.shape(1) == observation::Khalil2016Obs::n_features);
4057
}
4158

42-
SECTION("No features are NaN or infinite") {
43-
auto const& obs = optional_obs.value();
44-
REQUIRE_FALSE(xt::any(xt::isnan(obs.features)));
45-
REQUIRE_FALSE(xt::any(xt::isinf(obs.features)));
46-
}
47-
4859
SECTION("Observation has correct values") {
4960
auto const& obs = optional_obs.value();
50-
auto col = [&obs](auto feat) { return xt::col(obs.features, static_cast<std::ptrdiff_t>(feat)); };
61+
auto const branch_cands = pseudo ? model.pseudo_branch_cands() : model.lp_branch_cands();
62+
auto obs_pseudo = obs_pseudo_cands(obs.features, views::transform(branch_cands, SCIPvarGetProbindex));
63+
auto col = [&obs_pseudo](auto feat) { return xt::col(obs_pseudo, static_cast<std::ptrdiff_t>(feat)); };
64+
65+
SECTION("No pseudo_candidate features are NaN or infinite") {
66+
for (auto* var : branch_cands) {
67+
auto const var_idx = SCIPvarGetProbindex(var);
68+
REQUIRE_FALSE(xt::any(xt::isnan(xt::row(obs.features, var_idx))));
69+
REQUIRE_FALSE(xt::any(xt::isinf(xt::row(obs.features, var_idx))));
70+
}
71+
}
5172

5273
SECTION("Objective function coefficients") {
5374
REQUIRE(xt::all(col(Features::obj_coef_pos_part) >= 0));

python/src/ecole/core/observation.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,10 @@ void bind_submodule(py::module_ const& m) {
263263
auto_class<Khalil2016Obs>(m, "Khalil2016Obs", R"(
264264
Branching candidates features from Khalil et al. (2016).
265265
266-
The observation is a matrix where rows represent pseudo branching candidates and columns
267-
represent features related to these variables.
266+
The observation is a matrix where rows represent all variables and columns represent features related
267+
to these variables.
268+
Only rows representing pseudo branching candidate contain meaningful observation, other rows are filled with
269+
``NaN``.
268270
See [Khalil2016]_ for a complete reference on this observation function.
269271
270272
The first :py:attr:`Khalil2016Obs.n_static_features` are static (they do not change through the solving

0 commit comments

Comments
 (0)