Skip to content

Commit aa7dcbb

Browse files
committed
Fix Khalil2016 tests
1 parent 3657302 commit aa7dcbb

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

libecole/src/observation/khalil-2016.cpp

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -554,34 +554,31 @@ void set_dynamic_features(
554554
template <typename Tensor>
555555
void set_precomputed_static_features(
556556
Tensor&& out,
557-
SCIP_VAR* const var,
557+
std::size_t var_idx,
558558
xt::xtensor<value_type, 2> const& static_features) {
559559

560-
auto const col_idx = static_cast<std::ptrdiff_t>(SCIPcolGetIndex(SCIPvarGetCol(var)));
561560
using namespace xt::placeholders;
562-
xt::view(out, xt::range(_, Khalil2016Obs::n_static_features)) = xt::row(static_features, col_idx);
561+
xt::view(out, xt::range(_, Khalil2016Obs::n_static_features)) =
562+
xt::row(static_features, static_cast<std::ptrdiff_t>(var_idx));
563563
}
564564

565565
/******************************
566566
* Main extraction function *
567567
******************************/
568568

569569
auto extract_all_features(scip::Model& model, bool pseudo, xt::xtensor<value_type, 2> const& static_features) {
570-
xt::xtensor<value_type, 2> observation{
571-
{model.pseudo_branch_cands().size(), Khalil2016Obs::n_features},
572-
std::nan(""),
573-
};
570+
auto const branch_cands = pseudo ? model.pseudo_branch_cands() : model.lp_branch_cands();
571+
auto const n_branch_cands = branch_cands.size();
572+
573+
auto observation = xt::xtensor<value_type, 2>{{n_branch_cands, Khalil2016Obs::n_features}, std::nan("")};
574574

575575
auto* const scip = model.get_scip_ptr();
576576
auto const active_rows_weights = stats_for_active_constraint_coefficients_weights(model);
577577

578-
auto const branch_cands = pseudo ? model.pseudo_branch_cands() : model.lp_branch_cands();
579-
580-
auto const n_branch_cands = branch_cands.size();
581578
for (std::size_t var_idx = 0; var_idx < n_branch_cands; ++var_idx) {
582579
auto* const var = branch_cands[var_idx];
583580
auto features = xt::row(observation, static_cast<std::ptrdiff_t>(var_idx));
584-
set_precomputed_static_features(features, var, static_features);
581+
set_precomputed_static_features(features, var_idx, static_features);
585582
set_dynamic_features(features, scip, var, active_rows_weights);
586583
}
587584

libecole/tests/src/observation/test-khalil-2016.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
using namespace ecole;
1111

1212
TEST_CASE("Khalil2016 unit tests", "[unit][obs]") {
13-
observation::unit_tests(observation::Khalil2016{});
13+
auto const pseudo = GENERATE(true, false);
14+
observation::unit_tests(observation::Khalil2016{pseudo});
1415
}
1516

1617
template <typename Tensor, typename T = typename Tensor::value_type>
@@ -22,7 +23,8 @@ auto in_interval(Tensor const& tensor, T const& lower, T const& upper) {
2223
TEST_CASE("Khalil2016 return correct observation", "[obs]") {
2324
using Features = observation::Khalil2016Obs::Features;
2425

25-
auto obs_func = observation::Khalil2016{};
26+
auto const pseudo = GENERATE(true, false);
27+
auto obs_func = observation::Khalil2016{pseudo};
2628
auto model = get_model();
2729
obs_func.before_reset(model);
2830
advance_to_stage(model, SCIP_STAGE_SOLVING);
@@ -32,7 +34,8 @@ TEST_CASE("Khalil2016 return correct observation", "[obs]") {
3234

3335
SECTION("Observation features has correct shape") {
3436
auto const& obs = optional_obs.value();
35-
REQUIRE(obs.features.shape(0) == model.pseudo_branch_cands().size());
37+
auto const branch_cands = pseudo ? model.pseudo_branch_cands() : model.lp_branch_cands();
38+
REQUIRE(obs.features.shape(0) == branch_cands.size());
3639
REQUIRE(obs.features.shape(1) == observation::Khalil2016Obs::n_features);
3740
}
3841

0 commit comments

Comments
 (0)