Skip to content

Commit 83e80cf

Browse files
committed
[CP-SAT] more work on hints
1 parent 8e4ce6c commit 83e80cf

File tree

13 files changed

+89
-40
lines changed

13 files changed

+89
-40
lines changed

ortools/sat/cp_model_presolve.cc

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2942,6 +2942,34 @@ bool CpModelPresolver::PresolveSmallLinear(ConstraintProto* ct) {
29422942
return false;
29432943
}
29442944

2945+
namespace {
2946+
// Set the hint in `context` for the variable in `equality` that has no hint, if
2947+
// there is exactly one. Otherwise do nothing.
2948+
void MaybeComputeMissingHint(PresolveContext* context,
2949+
const LinearConstraintProto& equality) {
2950+
DCHECK(equality.domain_size() == 2 &&
2951+
equality.domain(0) == equality.domain(1));
2952+
if (!context->HintIsLoaded()) return;
2953+
int term_with_missing_hint = -1;
2954+
int64_t missing_term_value = equality.domain(0);
2955+
for (int i = 0; i < equality.vars_size(); ++i) {
2956+
if (context->VarHasSolutionHint(equality.vars(i))) {
2957+
missing_term_value -=
2958+
context->SolutionHint(equality.vars(i)) * equality.coeffs(i);
2959+
} else if (term_with_missing_hint == -1) {
2960+
term_with_missing_hint = i;
2961+
} else {
2962+
// More than one variable has a missing hint.
2963+
return;
2964+
}
2965+
}
2966+
if (term_with_missing_hint == -1) return;
2967+
context->SetNewVariableHint(
2968+
equality.vars(term_with_missing_hint),
2969+
missing_term_value / equality.coeffs(term_with_missing_hint));
2970+
}
2971+
} // namespace
2972+
29452973
bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) {
29462974
if (ct->constraint_case() != ConstraintProto::kLinear) return false;
29472975
if (ct->linear().vars().size() <= 1) return false;
@@ -3064,6 +3092,15 @@ bool CpModelPresolver::PresolveDiophantine(ConstraintProto* ct) {
30643092
}
30653093
}
30663094
context_->InitializeNewDomains();
3095+
// Scan the new constraints added above in reverse order so that the hint of
3096+
// `new_variables[k]` can be computed from the hint of the existing variables
3097+
// and from the hints of `new_variables[k']`, with k' > k.
3098+
const int num_constraints = context_->working_model->constraints_size();
3099+
for (int i = 0; i < num_replaced_variables; ++i) {
3100+
MaybeComputeMissingHint(
3101+
context_,
3102+
context_->working_model->constraints(num_constraints - 1 - i).linear());
3103+
}
30673104

30683105
if (VLOG_IS_ON(2)) {
30693106
std::string log_eq = absl::StrCat(linear_constraint.domain(0), " = ");
@@ -7457,7 +7494,7 @@ void CpModelPresolver::Probe() {
74577494
namespace {
74587495

74597496
bool FixFromAssignment(const VariablesAssignment& assignment,
7460-
const std::vector<int>& var_mapping,
7497+
absl::Span<const int> var_mapping,
74617498
PresolveContext* context) {
74627499
const int num_vars = assignment.NumberOfVariables();
74637500
for (int i = 0; i < num_vars; ++i) {

ortools/sat/cp_model_solver_helpers.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ IntegerVariable GetOrCreateVariableWithTightBound(
337337
}
338338

339339
IntegerVariable GetOrCreateVariableLinkedToSumOf(
340-
const std::vector<std::pair<IntegerVariable, int64_t>>& terms,
340+
absl::Span<const std::pair<IntegerVariable, int64_t>> terms,
341341
bool lb_required, bool ub_required, Model* model) {
342342
if (terms.empty()) return model->Add(ConstantIntegerVariable(0));
343343
if (terms.size() == 1 && terms.front().second == 1) {
@@ -1862,7 +1862,7 @@ void PostsolveResponseWithFullSolver(int num_variables_in_original_model,
18621862
void PostsolveResponseWrapper(const SatParameters& params,
18631863
int num_variable_in_original_model,
18641864
const CpModelProto& mapping_proto,
1865-
const std::vector<int>& postsolve_mapping,
1865+
absl::Span<const int> postsolve_mapping,
18661866
std::vector<int64_t>* solution) {
18671867
if (params.debug_postsolve_with_full_solver()) {
18681868
PostsolveResponseWithFullSolver(num_variable_in_original_model,

ortools/sat/cp_model_solver_helpers.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <vector>
2323

2424
#include "absl/flags/declare.h"
25+
#include "absl/types/span.h"
2526
#include "ortools/base/timer.h"
2627
#include "ortools/sat/cp_model.pb.h"
2728
#include "ortools/sat/integer_base.h"
@@ -128,7 +129,7 @@ int RegisterClausesLevelZeroImport(int id,
128129
void PostsolveResponseWrapper(const SatParameters& params,
129130
int num_variable_in_original_model,
130131
const CpModelProto& mapping_proto,
131-
const std::vector<int>& postsolve_mapping,
132+
absl::Span<const int> postsolve_mapping,
132133
std::vector<int64_t>* solution);
133134

134135
// Try to find a solution by following the hint and using a low conflict limit.

ortools/sat/cuts.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2414,7 +2414,7 @@ IntegerValue SumOfAllDiffLowerBounder::GetBestLowerBound(std::string& suffix) {
24142414
namespace {
24152415

24162416
void TryToGenerateAllDiffCut(
2417-
const std::vector<std::pair<double, AffineExpression>>& sorted_exprs_lp,
2417+
absl::Span<const std::pair<double, AffineExpression>> sorted_exprs_lp,
24182418
const IntegerTrail& integer_trail,
24192419
const util_intops::StrongVector<IntegerVariable, double>& lp_values,
24202420
TopNCuts& top_n_cuts, Model* model) {
@@ -2527,8 +2527,8 @@ IntegerValue MaxCornerDifference(const IntegerVariable var,
25272527
// target expr I(i), max expr k.
25282528
// The coefficient of zk is Sum(i=1..n)(MPlusCoefficient_ki) + bk
25292529
IntegerValue MPlusCoefficient(
2530-
const std::vector<IntegerVariable>& x_vars,
2531-
const std::vector<LinearExpression>& exprs,
2530+
absl::Span<const IntegerVariable> x_vars,
2531+
absl::Span<const LinearExpression> exprs,
25322532
const util_intops::StrongVector<IntegerVariable, int>& variable_partition,
25332533
const int max_index, const IntegerTrail& integer_trail) {
25342534
IntegerValue coeff = exprs[max_index].offset;
@@ -2659,7 +2659,7 @@ IntegerValue EvaluateMaxAffine(
26592659

26602660
bool BuildMaxAffineUpConstraint(
26612661
const LinearExpression& target, IntegerVariable var,
2662-
const std::vector<std::pair<IntegerValue, IntegerValue>>& affines,
2662+
absl::Span<const std::pair<IntegerValue, IntegerValue>> affines,
26632663
Model* model, LinearConstraintBuilder* builder) {
26642664
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
26652665
const IntegerValue x_min = integer_trail->LevelZeroLowerBound(var);

ortools/sat/cuts.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,7 @@ CutGenerator CreateLinMaxCutGenerator(
702702
// This function will reset the bounds of the builder.
703703
bool BuildMaxAffineUpConstraint(
704704
const LinearExpression& target, IntegerVariable var,
705-
const std::vector<std::pair<IntegerValue, IntegerValue>>& affines,
705+
absl::Span<const std::pair<IntegerValue, IntegerValue>> affines,
706706
Model* model, LinearConstraintBuilder* builder);
707707

708708
// By definition, the Max of affine functions is convex. The linear polytope is

ortools/sat/diffn.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ int NonOverlappingRectanglesEnergyPropagator::RegisterWith(
530530
}
531531

532532
bool NonOverlappingRectanglesEnergyPropagator::BuildAndReportEnergyTooLarge(
533-
const std::vector<RectangleInRange>& ranges) {
533+
absl::Span<const RectangleInRange> ranges) {
534534
if (ranges.size() == 2) {
535535
num_conflicts_two_boxes_++;
536536
return ClearAndAddTwoBoxesConflictReason(ranges[0].box_index,

ortools/sat/diffn.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ class NonOverlappingRectanglesEnergyPropagator : public PropagatorInterface {
6363

6464
std::vector<RectangleInRange> GeneralizeExplanation(const Conflict& conflict);
6565

66-
bool BuildAndReportEnergyTooLarge(
67-
const std::vector<RectangleInRange>& ranges);
66+
bool BuildAndReportEnergyTooLarge(absl::Span<const RectangleInRange> ranges);
6867

6968
SchedulingConstraintHelper& x_;
7069
SchedulingConstraintHelper& y_;

ortools/sat/diffn_util_test.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -315,11 +315,11 @@ std::vector<std::vector<int>> GetOverlappingIntervalComponentsBruteForce(
315315
components[component_indices[i]].push_back(i);
316316
}
317317
// Sort the components by start, like GetOverlappingIntervalComponents().
318-
absl::c_sort(components, [intervals](const std::vector<int>& c1,
319-
const std::vector<int>& c2) {
320-
CHECK(!c1.empty() && !c2.empty());
321-
return intervals[c1[0]].start < intervals[c2[0]].start;
322-
});
318+
absl::c_sort(components,
319+
[intervals](absl::Span<const int> c1, absl::Span<const int> c2) {
320+
CHECK(!c1.empty() && !c2.empty());
321+
return intervals[c1[0]].start < intervals[c2[0]].start;
322+
});
323323
// Inside each component, the intervals should be sorted, too.
324324
// Moreover, we need to convert our indices to IntervalIndex.index.
325325
for (std::vector<int>& component : components) {
@@ -736,7 +736,7 @@ void ReduceUntilDone(ProbingRectangle& ranges, absl::BitGen& random) {
736736
// detect a conflict even if there is one by looking only at those rectangles,
737737
// see the ProbingRectangleTest.CounterExample unit test for a concrete example.
738738
std::optional<Rectangle> FindRectangleWithEnergyTooLargeExhaustive(
739-
const std::vector<RectangleInRange>& box_ranges) {
739+
absl::Span<const RectangleInRange> box_ranges) {
740740
int num_boxes = box_ranges.size();
741741
std::vector<IntegerValue> x;
742742
x.reserve(num_boxes * 4);
@@ -957,8 +957,8 @@ TEST(FindPartialIntersections, Simple) {
957957
}
958958

959959
bool GraphsDefineSameConnectedComponents(
960-
const std::vector<std::pair<int, int>>& graph1,
961-
const std::vector<std::pair<int, int>>& graph2) {
960+
absl::Span<const std::pair<int, int>> graph1,
961+
absl::Span<const std::pair<int, int>> graph2) {
962962
int max = -1;
963963
int max2 = -1;
964964
for (const auto& [a, b] : graph1) {

ortools/sat/integer_expr.h

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -603,10 +603,16 @@ inline std::function<void(Model*)> ConditionalWeightedSumLowerOrEqual(
603603
};
604604
}
605605
inline std::function<void(Model*)> ConditionalWeightedSumGreaterOrEqual(
606-
const std::vector<Literal>& enforcement_literals,
607-
const std::vector<IntegerVariable>& vars,
608-
const std::vector<int64_t>& coefficients, int64_t upper_bound) {
609-
return [=](Model* model) {
606+
absl::Span<const Literal> enforcement_literals,
607+
absl::Span<const IntegerVariable> vars,
608+
absl::Span<const int64_t> coefficients, int64_t upper_bound) {
609+
return [=,
610+
coefficients =
611+
std::vector<int64_t>(coefficients.begin(), coefficients.end()),
612+
vars = std::vector<IntegerVariable>(vars.begin(), vars.end()),
613+
enforcement_literals =
614+
std::vector<Literal>(enforcement_literals.begin(),
615+
enforcement_literals.end())](Model* model) {
610616
AddWeightedSumGreaterOrEqual(enforcement_literals, vars, coefficients,
611617
upper_bound, model);
612618
};

ortools/sat/integer_search.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,8 +1116,8 @@ std::function<BooleanOrIntegerLiteral()> RandomizeOnRestartHeuristic(
11161116
}
11171117

11181118
std::function<BooleanOrIntegerLiteral()> FollowHint(
1119-
const std::vector<BooleanOrIntegerVariable>& vars,
1120-
const std::vector<IntegerValue>& values, Model* model) {
1119+
absl::Span<const BooleanOrIntegerVariable> vars,
1120+
absl::Span<const IntegerValue> values, Model* model) {
11211121
auto* trail = model->GetOrCreate<Trail>();
11221122
auto* integer_trail = model->GetOrCreate<IntegerTrail>();
11231123
auto* rev_int_repo = model->GetOrCreate<RevIntRepository>();
@@ -1130,7 +1130,10 @@ std::function<BooleanOrIntegerLiteral()> FollowHint(
11301130
int* rev_start_index = model->TakeOwnership(new int);
11311131
*rev_start_index = 0;
11321132

1133-
return [=]() {
1133+
return [=,
1134+
vars =
1135+
std::vector<BooleanOrIntegerVariable>(vars.begin(), vars.end()),
1136+
values = std::vector<IntegerValue>(values.begin(), values.end())]() {
11341137
rev_int_repo->SaveState(rev_start_index);
11351138
for (int i = *rev_start_index; i < vars.size(); ++i) {
11361139
const IntegerValue value = values[i];

0 commit comments

Comments
 (0)