Skip to content

Commit 6658e93

Browse files
committed
Code quality improvement in string branching scores
1 parent e3dd2ea commit 6658e93

File tree

1 file changed

+51
-60
lines changed

1 file changed

+51
-60
lines changed

libecole/src/observation/strongbranchingscores.cpp

Lines changed: 51 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,68 +17,59 @@ StrongBranchingScores::StrongBranchingScores(bool pseudo_candidates_) :
1717

1818
nonstd::optional<xt::xtensor<double, 1>>
1919
StrongBranchingScores::obtain_observation(scip::Model& model) {
20-
21-
if (model.get_stage() == SCIP_STAGE_SOLVING) {
22-
23-
SCIP* scip = model.get_scip_ptr();
24-
25-
/* store original SCIP parameters */
26-
auto const integralcands = model.get_param<bool>("branching/vanillafullstrong/integralcands");
27-
auto const scoreall = model.get_param<bool>("branching/vanillafullstrong/scoreall");
28-
auto const collectscores = model.get_param<bool>("branching/vanillafullstrong/collectscores");
29-
auto const donotbranch = model.get_param<bool>("branching/vanillafullstrong/donotbranch");
30-
auto const idempotent = model.get_param<bool>("branching/vanillafullstrong/idempotent");
31-
32-
/* set parameters for vanilla full strong branching */
33-
if (pseudo_candidates) {
34-
model.set_param("branching/vanillafullstrong/integralcands", true);
35-
} else {
36-
model.set_param("branching/vanillafullstrong/integralcands", false);
37-
}
38-
model.set_param("branching/vanillafullstrong/scoreall", true);
39-
model.set_param("branching/vanillafullstrong/collectscores", true);
40-
model.set_param("branching/vanillafullstrong/donotbranch", true);
41-
model.set_param("branching/vanillafullstrong/idempotent", true);
42-
43-
/* execute vanilla full strong branching */
44-
SCIP_BRANCHRULE* branchrule = SCIPfindBranchrule(scip, "vanillafullstrong");
45-
SCIP_RESULT result;
46-
scip::call(branchrule->branchexeclp, scip, branchrule, false, &result);
47-
assert(result == SCIP_DIDNOTRUN);
48-
49-
/* get vanilla full strong branching scores */
50-
SCIP_VAR** cands;
51-
SCIP_Real* candscores;
52-
int ncands;
53-
54-
SCIPgetVanillafullstrongData(scip, &cands, &candscores, &ncands, NULL, NULL);
55-
56-
assert(ncands >= 0);
57-
58-
/* restore model parameters */
59-
model.set_param("branching/vanillafullstrong/integralcands", integralcands);
60-
model.set_param("branching/vanillafullstrong/scoreall", scoreall);
61-
model.set_param("branching/vanillafullstrong/collectscores", collectscores);
62-
model.set_param("branching/vanillafullstrong/donotbranch", donotbranch);
63-
model.set_param("branching/vanillafullstrong/idempotent", idempotent);
64-
65-
/* Store strong branching scores in tensor */
66-
auto const num_lp_columns = static_cast<std::size_t>(SCIPgetNLPCols(scip));
67-
xt::xtensor<double, 1> strong_branching_scores({num_lp_columns}, std::nan(""));
68-
69-
SCIP_COL* col;
70-
int lp_index;
71-
for (int i = 0; i < ncands; i++) {
72-
col = SCIPvarGetCol(cands[i]);
73-
lp_index = SCIPcolGetLPPos(col);
74-
strong_branching_scores(lp_index) = static_cast<double>(candscores[i]);
75-
}
76-
77-
return strong_branching_scores;
78-
79-
} else {
20+
if (model.get_stage() != SCIP_STAGE_SOLVING) {
8021
return {};
8122
}
23+
24+
auto const scip = model.get_scip_ptr();
25+
26+
/* store original SCIP parameters */
27+
auto const integralcands = model.get_param<bool>("branching/vanillafullstrong/integralcands");
28+
auto const scoreall = model.get_param<bool>("branching/vanillafullstrong/scoreall");
29+
auto const collectscores = model.get_param<bool>("branching/vanillafullstrong/collectscores");
30+
auto const donotbranch = model.get_param<bool>("branching/vanillafullstrong/donotbranch");
31+
auto const idempotent = model.get_param<bool>("branching/vanillafullstrong/idempotent");
32+
33+
/* set parameters for vanilla full strong branching */
34+
model.set_param("branching/vanillafullstrong/integralcands", pseudo_candidates);
35+
model.set_param("branching/vanillafullstrong/scoreall", true);
36+
model.set_param("branching/vanillafullstrong/collectscores", true);
37+
model.set_param("branching/vanillafullstrong/donotbranch", true);
38+
model.set_param("branching/vanillafullstrong/idempotent", true);
39+
40+
/* execute vanilla full strong branching */
41+
auto branchrule = SCIPfindBranchrule(scip, "vanillafullstrong");
42+
SCIP_RESULT result;
43+
scip::call(branchrule->branchexeclp, scip, branchrule, false, &result);
44+
assert(result == SCIP_DIDNOTRUN);
45+
46+
/* get vanilla full strong branching scores */
47+
SCIP_VAR** cands;
48+
SCIP_Real* candscores;
49+
int ncands;
50+
51+
SCIPgetVanillafullstrongData(scip, &cands, &candscores, &ncands, NULL, NULL);
52+
53+
assert(ncands >= 0);
54+
55+
/* restore model parameters */
56+
model.set_param("branching/vanillafullstrong/integralcands", integralcands);
57+
model.set_param("branching/vanillafullstrong/scoreall", scoreall);
58+
model.set_param("branching/vanillafullstrong/collectscores", collectscores);
59+
model.set_param("branching/vanillafullstrong/donotbranch", donotbranch);
60+
model.set_param("branching/vanillafullstrong/idempotent", idempotent);
61+
62+
/* Store strong branching scores in tensor */
63+
auto const num_lp_columns = static_cast<std::size_t>(SCIPgetNLPCols(scip));
64+
auto strong_branching_scores = xt::xtensor<double, 1>({num_lp_columns}, std::nan(""));
65+
66+
for (std::size_t i = 0; i < static_cast<std::size_t>(ncands); i++) {
67+
auto const col = SCIPvarGetCol(cands[i]);
68+
auto const lp_index = static_cast<std::size_t>(SCIPcolGetLPPos(col));
69+
strong_branching_scores[lp_index] = static_cast<double>(candscores[i]);
70+
}
71+
72+
return strong_branching_scores;
8273
}
8374

8475
} // namespace observation

0 commit comments

Comments
 (0)