@@ -554,34 +554,31 @@ void set_dynamic_features(
554554template <typename Tensor>
555555void set_precomputed_static_features (
556556 Tensor&& out,
557- SCIP_VAR* const var ,
557+ std:: size_t var_idx ,
558558 xt::xtensor<value_type, 2 > const & static_features) {
559559
560- auto const col_idx = static_cast <std::ptrdiff_t >(SCIPcolGetIndex (SCIPvarGetCol (var)));
561560 using namespace xt ::placeholders;
562- xt::view (out, xt::range (_, Khalil2016Obs::n_static_features)) = xt::row (static_features, col_idx);
561+ xt::view (out, xt::range (_, Khalil2016Obs::n_static_features)) =
562+ xt::row (static_features, static_cast <std::ptrdiff_t >(var_idx));
563563}
564564
565565/* *****************************
566566 * Main extraction function *
567567 ******************************/
568568
569569auto extract_all_features (scip::Model& model, bool pseudo, xt::xtensor<value_type, 2 > const & static_features) {
570- xt::xtensor<value_type, 2 > observation{
571- {model. pseudo_branch_cands (). size (), Khalil2016Obs::n_features},
572- std::nan ( " " ),
573- };
570+ auto const branch_cands = pseudo ? model. pseudo_branch_cands () : model. lp_branch_cands ();
571+ auto const n_branch_cands = branch_cands. size ();
572+
573+ auto observation = xt::xtensor<value_type, 2 >{{n_branch_cands, Khalil2016Obs::n_features}, std::nan ( " " ) };
574574
575575 auto * const scip = model.get_scip_ptr ();
576576 auto const active_rows_weights = stats_for_active_constraint_coefficients_weights (model);
577577
578- auto const branch_cands = pseudo ? model.pseudo_branch_cands () : model.lp_branch_cands ();
579-
580- auto const n_branch_cands = branch_cands.size ();
581578 for (std::size_t var_idx = 0 ; var_idx < n_branch_cands; ++var_idx) {
582579 auto * const var = branch_cands[var_idx];
583580 auto features = xt::row (observation, static_cast <std::ptrdiff_t >(var_idx));
584- set_precomputed_static_features (features, var , static_features);
581+ set_precomputed_static_features (features, var_idx , static_features);
585582 set_dynamic_features (features, scip, var, active_rows_weights);
586583 }
587584
0 commit comments