Skip to content

Commit aa49de1

Browse files
authored
Merge pull request #234 from benoitsteiner/master
Fix for issue 224
2 parents 77911be + aa7dcbb commit aa49de1

File tree

4 files changed

+32
-18
lines changed

4 files changed

+32
-18
lines changed

libecole/include/ecole/observation/khalil-2016.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,14 @@ struct Khalil2016Obs {
108108

109109
class Khalil2016 : public ObservationFunction<std::optional<Khalil2016Obs>> {
110110
public:
111+
Khalil2016(bool pseudo_candidates = false) noexcept;
112+
111113
void before_reset(scip::Model& model) override;
112114

113115
std::optional<Khalil2016Obs> extract(scip::Model& model, bool done) override;
114116

115117
private:
118+
bool pseudo_candidates;
116119
xt::xtensor<double, 2> static_features;
117120
};
118121

libecole/src/observation/khalil-2016.cpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -554,33 +554,31 @@ void set_dynamic_features(
554554
template <typename Tensor>
555555
void set_precomputed_static_features(
556556
Tensor&& out,
557-
SCIP_VAR* const var,
557+
std::size_t var_idx,
558558
xt::xtensor<value_type, 2> const& static_features) {
559559

560-
auto const col_idx = static_cast<std::ptrdiff_t>(SCIPcolGetIndex(SCIPvarGetCol(var)));
561560
using namespace xt::placeholders;
562-
xt::view(out, xt::range(_, Khalil2016Obs::n_static_features)) = xt::row(static_features, col_idx);
561+
xt::view(out, xt::range(_, Khalil2016Obs::n_static_features)) =
562+
xt::row(static_features, static_cast<std::ptrdiff_t>(var_idx));
563563
}
564564

565565
/******************************
566566
* Main extraction function *
567567
******************************/
568568

569-
auto extract_all_features(scip::Model& model, xt::xtensor<value_type, 2> const& static_features) {
570-
xt::xtensor<value_type, 2> observation{
571-
{model.pseudo_branch_cands().size(), Khalil2016Obs::n_features},
572-
std::nan(""),
573-
};
569+
auto extract_all_features(scip::Model& model, bool pseudo, xt::xtensor<value_type, 2> const& static_features) {
570+
auto const branch_cands = pseudo ? model.pseudo_branch_cands() : model.lp_branch_cands();
571+
auto const n_branch_cands = branch_cands.size();
572+
573+
auto observation = xt::xtensor<value_type, 2>{{n_branch_cands, Khalil2016Obs::n_features}, std::nan("")};
574574

575575
auto* const scip = model.get_scip_ptr();
576576
auto const active_rows_weights = stats_for_active_constraint_coefficients_weights(model);
577577

578-
auto const pseudo_branch_cands = model.pseudo_branch_cands();
579-
auto const n_pseudo_branch_cands = pseudo_branch_cands.size();
580-
for (std::size_t var_idx = 0; var_idx < n_pseudo_branch_cands; ++var_idx) {
581-
auto* const var = pseudo_branch_cands[var_idx];
578+
for (std::size_t var_idx = 0; var_idx < n_branch_cands; ++var_idx) {
579+
auto* const var = branch_cands[var_idx];
582580
auto features = xt::row(observation, static_cast<std::ptrdiff_t>(var_idx));
583-
set_precomputed_static_features(features, var, static_features);
581+
set_precomputed_static_features(features, var_idx, static_features);
584582
set_dynamic_features(features, scip, var, active_rows_weights);
585583
}
586584

@@ -598,6 +596,8 @@ auto is_on_root_node(scip::Model& model) -> bool {
598596
* Observation extracting function *
599597
*************************************/
600598

599+
Khalil2016::Khalil2016(bool pseudo_candidates_) noexcept : pseudo_candidates(pseudo_candidates_) {}
600+
601601
void Khalil2016::before_reset(scip::Model& /* model */) {
602602
static_features = decltype(static_features){};
603603
}
@@ -607,7 +607,7 @@ auto Khalil2016::extract(scip::Model& model, bool /* done */) -> std::optional<K
607607
if (is_on_root_node(model)) {
608608
static_features = extract_static_features(model);
609609
}
610-
return {{extract_all_features(model, static_features)}};
610+
return {{extract_all_features(model, pseudo_candidates, static_features)}};
611611
}
612612
return {};
613613
}

libecole/tests/src/observation/test-khalil-2016.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
using namespace ecole;
1111

1212
TEST_CASE("Khalil2016 unit tests", "[unit][obs]") {
13-
observation::unit_tests(observation::Khalil2016{});
13+
auto const pseudo = GENERATE(true, false);
14+
observation::unit_tests(observation::Khalil2016{pseudo});
1415
}
1516

1617
template <typename Tensor, typename T = typename Tensor::value_type>
@@ -22,7 +23,8 @@ auto in_interval(Tensor const& tensor, T const& lower, T const& upper) {
2223
TEST_CASE("Khalil2016 return correct observation", "[obs]") {
2324
using Features = observation::Khalil2016Obs::Features;
2425

25-
auto obs_func = observation::Khalil2016{};
26+
auto const pseudo = GENERATE(true, false);
27+
auto obs_func = observation::Khalil2016{pseudo};
2628
auto model = get_model();
2729
obs_func.before_reset(model);
2830
advance_to_stage(model, SCIP_STAGE_SOLVING);
@@ -32,7 +34,8 @@ TEST_CASE("Khalil2016 return correct observation", "[obs]") {
3234

3335
SECTION("Observation features has correct shape") {
3436
auto const& obs = optional_obs.value();
35-
REQUIRE(obs.features.shape(0) == model.pseudo_branch_cands().size());
37+
auto const branch_cands = pseudo ? model.pseudo_branch_cands() : model.lp_branch_cands();
38+
REQUIRE(obs.features.shape(0) == branch_cands.size());
3639
REQUIRE(obs.features.shape(1) == observation::Khalil2016Obs::n_features);
3740
}
3841

python/src/ecole/core/observation.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,15 @@ void bind_submodule(py::module_ const& m) {
364364
365365
This observation function extract structured :py:class:`Khalil2016Obs`.
366366
)");
367-
khalil2016.def(py::init<>());
367+
khalil2016.def(py::init<bool>(), py::arg("pseudo_candidates") = false, R"(
368+
Create new observation.
369+
370+
Parameters
371+
----------
372+
pseudo_candidates:
373+
Whether the pseudo branching variable candidates (``SCIPgetPseudoBranchCands``)
374+
or LP branching variable candidates (``SCIPgetPseudoBranchCands``) are observed.
375+
)");
368376
def_before_reset(khalil2016, R"(Reset static features cache.)");
369377
def_extract(khalil2016, "Extract the observation matrix.");
370378

0 commit comments

Comments
 (0)