Skip to content

Commit 1d9299e

Browse files
committed
add necessary API and fix #5083
1 parent e865aa0 commit 1d9299e

File tree

4 files changed

+181
-2
lines changed

4 files changed

+181
-2
lines changed

ortools/constraint_solver/python/constraint_solver.cc

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ using ::operations_research::IntTupleSet;
6464
using ::operations_research::IntVar;
6565
using ::operations_research::IntVarElement;
6666
using ::operations_research::IntVarIterator;
67+
using ::operations_research::IntVarLocalSearchOperator;
68+
using ::operations_research::LocalSearchFilter;
6769
using ::operations_research::LocalSearchFilterManager;
6870
using ::operations_research::LocalSearchOperator;
6971
using ::operations_research::LocalSearchPhaseParameters;
@@ -299,6 +301,48 @@ class PyDecisionBuilderHelper : public PyDecisionBuilder {
299301
}
300302
};
301303

304+
class PyLocalSearchOperator : public LocalSearchOperator {
305+
public:
306+
using LocalSearchOperator::LocalSearchOperator;
307+
bool MakeNextNeighbor(Assignment* delta, Assignment* deltadelta) override {
308+
PYBIND11_OVERRIDE_PURE_NAME(bool, LocalSearchOperator, "next_neighbor",
309+
MakeNextNeighbor, delta, deltadelta);
310+
}
311+
void Start(const Assignment* assignment) override {
312+
PYBIND11_OVERRIDE_PURE_NAME(void, LocalSearchOperator, "start", Start,
313+
assignment);
314+
}
315+
void EnterSearch() override {
316+
PYBIND11_OVERRIDE_NAME(void, LocalSearchOperator, "enter_search",
317+
EnterSearch, );
318+
}
319+
void Reset() override {
320+
PYBIND11_OVERRIDE_NAME(void, LocalSearchOperator, "reset", Reset, );
321+
}
322+
bool HasFragments() const override {
323+
PYBIND11_OVERRIDE_NAME(bool, LocalSearchOperator, "has_fragments",
324+
HasFragments, );
325+
}
326+
bool HoldsDelta() const override {
327+
PYBIND11_OVERRIDE_NAME(bool, LocalSearchOperator, "holds_delta",
328+
HoldsDelta, );
329+
}
330+
};
331+
332+
class PyIntVarLocalSearchOperator : public IntVarLocalSearchOperator {
333+
public:
334+
using IntVarLocalSearchOperator::IntVarLocalSearchOperator;
335+
using IntVarLocalSearchOperator::MakeOneNeighbor;
336+
bool MakeOneNeighbor() override {
337+
PYBIND11_OVERRIDE_NAME(bool, IntVarLocalSearchOperator, "one_neighbor",
338+
MakeOneNeighbor, );
339+
}
340+
void OnStart() override {
341+
PYBIND11_OVERRIDE_NAME(void, IntVarLocalSearchOperator, "on_start",
342+
OnStart, );
343+
}
344+
};
345+
302346
class PySearchMonitor : public SearchMonitor {
303347
public:
304348
using SearchMonitor::SearchMonitor;
@@ -2540,6 +2584,41 @@ PYBIND11_MODULE(constraint_solver, m) {
25402584
},
25412585
py::return_value_policy::reference_internal);
25422586

2587+
py::class_<LocalSearchOperator, BaseObject, PyLocalSearchOperator>(
2588+
m, "LocalSearchOperator", DOC(operations_research, LocalSearchOperator))
2589+
.def(py::init<>())
2590+
.def("next_neighbor", &LocalSearchOperator::MakeNextNeighbor,
2591+
py::arg("delta"), py::arg("deltadelta"))
2592+
.def("start", &LocalSearchOperator::Start, py::arg("assignment"))
2593+
.def("enter_search", &LocalSearchOperator::EnterSearch)
2594+
.def("reset", &LocalSearchOperator::Reset)
2595+
.def("has_fragments", &LocalSearchOperator::HasFragments)
2596+
.def("holds_delta", &LocalSearchOperator::HoldsDelta);
2597+
2598+
py::class_<IntVarLocalSearchOperator, LocalSearchOperator,
2599+
PyIntVarLocalSearchOperator>(m, "IntVarLocalSearchOperator", "")
2600+
.def(py::init<const std::vector<IntVar*>&, bool>(), py::arg("vars"),
2601+
py::arg("keep_inverse_values") = false)
2602+
.def("size", &IntVarLocalSearchOperator::Size)
2603+
.def("var", &IntVarLocalSearchOperator::Var, py::arg("index"),
2604+
py::return_value_policy::reference_internal)
2605+
.def("is_incremental", &IntVarLocalSearchOperator::IsIncremental)
2606+
.def("value", &IntVarLocalSearchOperator::Value, py::arg("index"))
2607+
.def("old_value", &IntVarLocalSearchOperator::OldValue, py::arg("index"))
2608+
.def("prev_value", &IntVarLocalSearchOperator::PrevValue,
2609+
py::arg("index"))
2610+
.def("set_value", &IntVarLocalSearchOperator::SetValue, py::arg("index"),
2611+
py::arg("value"))
2612+
.def("activate", &IntVarLocalSearchOperator::Activate, py::arg("index"))
2613+
.def("deactivate", &IntVarLocalSearchOperator::Deactivate,
2614+
py::arg("index"))
2615+
.def("activated", &IntVarLocalSearchOperator::Activated, py::arg("index"))
2616+
.def("add_vars", &IntVarLocalSearchOperator::AddVars, py::arg("vars"))
2617+
.def(
2618+
"one_neighbor",
2619+
[](PyIntVarLocalSearchOperator* op) { return op->MakeOneNeighbor(); })
2620+
.def("on_start", &IntVarLocalSearchOperator::OnStart);
2621+
25432622
py::class_<DecisionBuilder, BaseObject>(
25442623
m, "DecisionBuilderBase", DOC(operations_research, DecisionBuilder))
25452624
.def("__str__", &DecisionBuilder::DebugString)

ortools/constraint_solver/python/constraint_solver_test.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def test_add_cumulative(self):
226226
]
227227

228228
solver.add_cumulative(intervals, demands, capacity, "cumul")
229+
229230
# Since capacity is 1 and demands are 1, they cannot overlap.
230231

231232
def check_single_solution():
@@ -623,6 +624,7 @@ def test_add_sub_circuit(self):
623624
solver = cp.Solver("test_sub_circuit")
624625
x = [solver.new_int_var(0, 2, f"x{i}") for i in range(3)]
625626
solver.add_sub_circuit(x)
627+
626628
# SubCircuit allows partial circuits (if x[i] == i, it's not in the
627629
# circuit). But if in circuit, must form a single circuit.
628630

@@ -1030,5 +1032,54 @@ def test_new_interval_relaxed_max(self):
10301032
self.assertIsNotNone(r)
10311033

10321034

1035+
class IntVarLocalSearchOperatorTest(absltest.TestCase):
1036+
1037+
def test_subclass_int_var_local_search_operator(self):
1038+
class MoveOneVar(cp.IntVarLocalSearchOperator):
1039+
1040+
def __init__(self, int_vars):
1041+
super().__init__(int_vars)
1042+
self.__index = 0
1043+
self.__up = False
1044+
1045+
def one_neighbor(self):
1046+
current_value = self.old_value(self.__index)
1047+
if self.__up:
1048+
self.set_value(self.__index, current_value + 1)
1049+
self.__index = (self.__index + 1) % self.size()
1050+
else:
1051+
self.set_value(self.__index, current_value - 1)
1052+
self.__up = not self.__up
1053+
return True
1054+
1055+
def on_start(self):
1056+
pass
1057+
1058+
solver = cp.Solver("test_subclass")
1059+
x = solver.new_int_var(0, 10, "x")
1060+
y = solver.new_int_var(0, 10, "y")
1061+
move_one_var = MoveOneVar([x, y])
1062+
self.assertIsNotNone(move_one_var)
1063+
1064+
1065+
class LocalSearchOperatorTest(absltest.TestCase):
1066+
1067+
def test_subclass_local_search_operator(self):
1068+
class CustomLSOperator(cp.LocalSearchOperator):
1069+
1070+
def __init__(self):
1071+
super().__init__()
1072+
1073+
def next_neighbor(self, delta, deltadelta):
1074+
return False
1075+
1076+
def start(self, assignment):
1077+
pass
1078+
1079+
solver = cp.Solver("test_subclass_ls")
1080+
op = CustomLSOperator()
1081+
self.assertIsNotNone(op)
1082+
1083+
10331084
if __name__ == "__main__":
10341085
absltest.main()

ortools/routing/python/routing.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -897,10 +897,12 @@ PYBIND11_MODULE(routing, m) {
897897

898898
rm.def("add_local_search_operator", &Model::AddLocalSearchOperator,
899899
py::arg("ls_operator"),
900-
DOC(operations_research, routing, Model, AddLocalSearchOperator));
900+
DOC(operations_research, routing, Model, AddLocalSearchOperator),
901+
py::keep_alive<1, 2>());
901902
rm.def("add_local_search_filter", &Model::AddLocalSearchFilter,
902903
py::arg("filter"),
903-
DOC(operations_research, routing, Model, AddLocalSearchFilter));
904+
DOC(operations_research, routing, Model, AddLocalSearchFilter),
905+
py::keep_alive<1, 2>());
904906
rm.def(
905907
"apply_locks",
906908
[](Model* model, const std::vector<int64_t>& locks) {

ortools/routing/python/routing_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,53 @@ def test_initial_solution(self) -> None:
639639

640640
self.assertIsNotNone(solution)
641641

642+
def test_issue5083(self):
643+
644+
class MyLocalSearch(constraint_solver.IntVarLocalSearchOperator):
645+
646+
def __init__(self, int_vars):
647+
constraint_solver.IntVarLocalSearchOperator.__init__(self, int_vars)
648+
self.__index = 0
649+
650+
def one_neighbor(self):
651+
current_value = self.old_value(self.__index)
652+
self.set_value(self.__index, 1 - current_value)
653+
self.__index = (self.__index + 1) % self.Size()
654+
return True
655+
656+
def on_start(self):
657+
pass
658+
659+
def is_incremental(self):
660+
return False
661+
662+
manager = routing.IndexManager(3, 1, 0)
663+
model = routing.Model(manager)
664+
model.add_to_assignment(model.active_var(manager.node_to_index(1)))
665+
model.add_to_assignment(model.active_var(manager.node_to_index(2)))
666+
model.add_disjunction([manager.node_to_index(1)], 1)
667+
model.add_disjunction([manager.node_to_index(2)], 1)
668+
669+
model.add_local_search_operator(
670+
MyLocalSearch(
671+
[
672+
model.active_var(manager.node_to_index(1)),
673+
model.active_var(manager.node_to_index(2)),
674+
]
675+
)
676+
)
677+
678+
search_parameters = routing.default_routing_search_parameters()
679+
search_parameters.first_solution_strategy = (
680+
enums_pb2.FirstSolutionStrategy.AUTOMATIC
681+
)
682+
search_parameters.local_search_metaheuristic = (
683+
enums_pb2.LocalSearchMetaheuristic.AUTOMATIC
684+
)
685+
686+
solution = model.solve_with_parameters(search_parameters)
687+
print(solution)
688+
642689

643690
if __name__ == "__main__":
644691
absltest.main()

0 commit comments

Comments
 (0)