@@ -554,33 +554,31 @@ void set_dynamic_features(
554554template <typename Tensor>
555555void set_precomputed_static_features (
556556 Tensor&& out,
557- SCIP_VAR* const var ,
557+ std:: size_t var_idx ,
558558 xt::xtensor<value_type, 2 > const & static_features) {
559559
560- auto const col_idx = static_cast <std::ptrdiff_t >(SCIPcolGetIndex (SCIPvarGetCol (var)));
561560 using namespace xt ::placeholders;
562- xt::view (out, xt::range (_, Khalil2016Obs::n_static_features)) = xt::row (static_features, col_idx);
561+ xt::view (out, xt::range (_, Khalil2016Obs::n_static_features)) =
562+ xt::row (static_features, static_cast <std::ptrdiff_t >(var_idx));
563563}
564564
565565/* *****************************
566566 * Main extraction function *
567567 ******************************/
568568
569- auto extract_all_features (scip::Model& model, xt::xtensor<value_type, 2 > const & static_features) {
570- xt::xtensor<value_type, 2 > observation{
571- {model. pseudo_branch_cands (). size (), Khalil2016Obs::n_features},
572- std::nan ( " " ),
573- };
569+ auto extract_all_features (scip::Model& model, bool pseudo, xt::xtensor<value_type, 2 > const & static_features) {
570+ auto const branch_cands = pseudo ? model. pseudo_branch_cands () : model. lp_branch_cands ();
571+ auto const n_branch_cands = branch_cands. size ();
572+
573+ auto observation = xt::xtensor<value_type, 2 >{{n_branch_cands, Khalil2016Obs::n_features}, std::nan ( " " ) };
574574
575575 auto * const scip = model.get_scip_ptr ();
576576 auto const active_rows_weights = stats_for_active_constraint_coefficients_weights (model);
577577
578- auto const pseudo_branch_cands = model.pseudo_branch_cands ();
579- auto const n_pseudo_branch_cands = pseudo_branch_cands.size ();
580- for (std::size_t var_idx = 0 ; var_idx < n_pseudo_branch_cands; ++var_idx) {
581- auto * const var = pseudo_branch_cands[var_idx];
578+ for (std::size_t var_idx = 0 ; var_idx < n_branch_cands; ++var_idx) {
579+ auto * const var = branch_cands[var_idx];
582580 auto features = xt::row (observation, static_cast <std::ptrdiff_t >(var_idx));
583- set_precomputed_static_features (features, var , static_features);
581+ set_precomputed_static_features (features, var_idx , static_features);
584582 set_dynamic_features (features, scip, var, active_rows_weights);
585583 }
586584
@@ -598,6 +596,8 @@ auto is_on_root_node(scip::Model& model) -> bool {
598596 * Observation extracting function *
599597 *************************************/
600598
599+ Khalil2016::Khalil2016 (bool pseudo_candidates_) noexcept : pseudo_candidates(pseudo_candidates_) {}
600+
601601void Khalil2016::before_reset (scip::Model& /* model */ ) {
602602 static_features = decltype (static_features){};
603603}
@@ -607,7 +607,7 @@ auto Khalil2016::extract(scip::Model& model, bool /* done */) -> std::optional<K
607607 if (is_on_root_node (model)) {
608608 static_features = extract_static_features (model);
609609 }
610- return {{extract_all_features (model, static_features)}};
610+ return {{extract_all_features (model, pseudo_candidates, static_features)}};
611611 }
612612 return {};
613613}
0 commit comments