11#include < catch2/catch.hpp>
2+ #include < range/v3/view/enumerate.hpp>
3+ #include < range/v3/view/transform.hpp>
4+ #include < scip/scip.h>
5+ #include < xtensor/xindex_view.hpp>
26#include < xtensor/xmath.hpp>
7+ #include < xtensor/xtensor.hpp>
38#include < xtensor/xview.hpp>
49
510#include " ecole/observation/khalil-2016.hpp"
11+ #include " ecole/tweak/range.hpp"
612
713#include " conftest.hpp"
814#include " observation/unit-tests.hpp"
915
16+ namespace views = ranges::views;
17+
1018using namespace ecole ;
1119
1220TEST_CASE (" Khalil2016 unit tests" , " [unit][obs]" ) {
@@ -20,6 +28,16 @@ auto in_interval(Tensor const& tensor, T const& lower, T const& upper) {
2028 return (lower <= tensor) && (tensor <= upper);
2129}
2230
31+ /* * Get the features of the pseudo candidate only. */
32+ template <typename Tensor, typename Range>
33+ auto obs_pseudo_cands (Tensor const & obs_features, Range const & pseudo_cands_idx) -> Tensor {
34+ auto filtered_features = Tensor::from_shape ({pseudo_cands_idx.size (), obs_features.shape ()[1 ]});
35+ for (auto const [idx, var_idx] : views::enumerate (pseudo_cands_idx)) {
36+ xt::row (filtered_features, static_cast <std::ptrdiff_t >(idx)) = xt::row (obs_features, var_idx);
37+ }
38+ return filtered_features;
39+ }
40+
2341TEST_CASE (" Khalil2016 return correct observation" , " [obs]" ) {
2442 using Features = observation::Khalil2016Obs::Features;
2543
@@ -34,20 +52,23 @@ TEST_CASE("Khalil2016 return correct observation", "[obs]") {
3452
3553 SECTION (" Observation features has correct shape" ) {
3654 auto const & obs = optional_obs.value ();
37- auto const branch_cands = pseudo ? model.pseudo_branch_cands () : model.lp_branch_cands ();
38- REQUIRE (obs.features .shape (0 ) == branch_cands.size ());
55+ REQUIRE (obs.features .shape (0 ) == model.variables ().size ());
3956 REQUIRE (obs.features .shape (1 ) == observation::Khalil2016Obs::n_features);
4057 }
4158
42- SECTION (" No features are NaN or infinite" ) {
43- auto const & obs = optional_obs.value ();
44- REQUIRE_FALSE (xt::any (xt::isnan (obs.features )));
45- REQUIRE_FALSE (xt::any (xt::isinf (obs.features )));
46- }
47-
4859 SECTION (" Observation has correct values" ) {
4960 auto const & obs = optional_obs.value ();
50- auto col = [&obs](auto feat) { return xt::col (obs.features , static_cast <std::ptrdiff_t >(feat)); };
61+ auto const branch_cands = pseudo ? model.pseudo_branch_cands () : model.lp_branch_cands ();
62+ auto obs_pseudo = obs_pseudo_cands (obs.features , views::transform (branch_cands, SCIPvarGetProbindex));
63+ auto col = [&obs_pseudo](auto feat) { return xt::col (obs_pseudo, static_cast <std::ptrdiff_t >(feat)); };
64+
65+ SECTION (" No pseudo_candidate features are NaN or infinite" ) {
66+ for (auto * var : branch_cands) {
67+ auto const var_idx = SCIPvarGetProbindex (var);
68+ REQUIRE_FALSE (xt::any (xt::isnan (xt::row (obs.features , var_idx))));
69+ REQUIRE_FALSE (xt::any (xt::isinf (xt::row (obs.features , var_idx))));
70+ }
71+ }
5172
5273 SECTION (" Objective function coefficients" ) {
5374 REQUIRE (xt::all (col (Features::obj_coef_pos_part) >= 0 ));
0 commit comments