11#include < catch2/catch.hpp>
2+ #include < range/v3/view/enumerate.hpp>
3+ #include < range/v3/view/transform.hpp>
4+ #include < scip/scip.h>
5+ #include < xtensor/xindex_view.hpp>
26#include < xtensor/xmath.hpp>
7+ #include < xtensor/xtensor.hpp>
38#include < xtensor/xview.hpp>
49
510#include " ecole/observation/khalil-2016.hpp"
11+ #include " ecole/tweak/range.hpp"
612
713#include " conftest.hpp"
814#include " observation/unit-tests.hpp"
915
16+ namespace views = ranges::views;
17+
1018using namespace ecole ;
1119
1220TEST_CASE (" Khalil2016 unit tests" , " [unit][obs]" ) {
@@ -19,6 +27,16 @@ auto in_interval(Tensor const& tensor, T const& lower, T const& upper) {
1927 return (lower <= tensor) && (tensor <= upper);
2028}
2129
30+ /* * Get the features of the pseudo candidate only. */
31+ template <typename Tensor, typename Range>
32+ auto obs_pseudo_cands (Tensor const & obs_features, Range const & pseudo_cands_idx) -> Tensor {
33+ auto filtered_features = Tensor::from_shape ({pseudo_cands_idx.size (), obs_features.shape ()[1 ]});
34+ for (auto const [idx, var_idx] : views::enumerate (pseudo_cands_idx)) {
35+ xt::row (filtered_features, static_cast <std::ptrdiff_t >(idx)) = xt::row (obs_features, var_idx);
36+ }
37+ return filtered_features;
38+ }
39+
2240TEST_CASE (" Khalil2016 return correct observation" , " [obs]" ) {
2341 using Features = observation::Khalil2016Obs::Features;
2442
@@ -32,19 +50,23 @@ TEST_CASE("Khalil2016 return correct observation", "[obs]") {
3250
3351 SECTION (" Observation features has correct shape" ) {
3452 auto const & obs = optional_obs.value ();
35- REQUIRE (obs.features .shape (0 ) == model.pseudo_branch_cands ().size ());
53+ REQUIRE (obs.features .shape (0 ) == model.variables ().size ());
3654 REQUIRE (obs.features .shape (1 ) == observation::Khalil2016Obs::n_features);
3755 }
3856
39- SECTION (" No features are NaN or infinite" ) {
40- auto const & obs = optional_obs.value ();
41- REQUIRE_FALSE (xt::any (xt::isnan (obs.features )));
42- REQUIRE_FALSE (xt::any (xt::isinf (obs.features )));
43- }
44-
4557 SECTION (" Observation has correct values" ) {
4658 auto const & obs = optional_obs.value ();
47- auto col = [&obs](auto feat) { return xt::col (obs.features , static_cast <std::ptrdiff_t >(feat)); };
59+ auto obs_pseudo =
60+ obs_pseudo_cands (obs.features , views::transform (model.pseudo_branch_cands (), SCIPvarGetProbindex));
61+ auto col = [&obs_pseudo](auto feat) { return xt::col (obs_pseudo, static_cast <std::ptrdiff_t >(feat)); };
62+
63+ SECTION (" No pseudo_candidate features are NaN or infinite" ) {
64+ for (auto * var : model.pseudo_branch_cands ()) {
65+ auto const var_idx = SCIPvarGetProbindex (var);
66+ REQUIRE_FALSE (xt::any (xt::isnan (xt::row (obs.features , var_idx))));
67+ REQUIRE_FALSE (xt::any (xt::isinf (xt::row (obs.features , var_idx))));
68+ }
69+ }
4870
4971 SECTION (" Objective function coefficients" ) {
5072 REQUIRE (xt::all (col (Features::obj_coef_pos_part) >= 0 ));
0 commit comments