Skip to content

Commit 6be2ba2

Browse files
committed
[CP-SAT] fix memory management of the python layer; properly fails with None arguments
1 parent 57f75f8 commit 6be2ba2

File tree

4 files changed

+103
-76
lines changed

4 files changed

+103
-76
lines changed

ortools/sat/python/cp_model_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1687,6 +1687,17 @@ def testIntervalVarSeries(self) -> None:
16871687
)
16881688
self.assertLen(model.proto.constraints, 13)
16891689

1690+
def testCompareWithNone(self) -> None:
1691+
print("testCompareWithNone")
1692+
model = cp_model.CpModel()
1693+
x = model.new_int_var(0, 10, "x")
1694+
self.assertRaises(TypeError, x.__eq__, None)
1695+
self.assertRaises(TypeError, x.__ne__, None)
1696+
self.assertRaises(TypeError, x.__lt__, None)
1697+
self.assertRaises(TypeError, x.__le__, None)
1698+
self.assertRaises(TypeError, x.__gt__, None)
1699+
self.assertRaises(TypeError, x.__ge__, None)
1700+
16901701
def testIssue4376SatModel(self) -> None:
16911702
print("testIssue4376SatModel")
16921703
letters: str = "BCFLMRT"

ortools/sat/python/linear_expr.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ double FloatExprVisitor::Process(FloatLinearExpr* expr,
153153
vars->clear();
154154
coeffs->clear();
155155
for (const auto& [var, coeff] : canonical_terms_) {
156-
if (coeff == 0) continue;
156+
if (coeff == 0.0) continue;
157157
vars->push_back(var);
158158
coeffs->push_back(coeff);
159159
}
@@ -473,6 +473,10 @@ bool BaseIntVarComparator::operator()(const BaseIntVar* lhs,
473473
return lhs->index() < rhs->index();
474474
}
475475

476+
BaseIntVar::BaseIntVar(int index, bool is_boolean)
477+
: index_(index),
478+
negated_(is_boolean ? new NotBooleanVariable(this) : nullptr) {}
479+
476480
BoundedLinearExpression::BoundedLinearExpression(IntLinExpr* expr,
477481
const Domain& bounds)
478482
: bounds_(bounds) {

ortools/sat/python/linear_expr.h

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -481,22 +481,20 @@ class Literal {
481481
public:
482482
virtual ~Literal() = default;
483483
virtual int index() const = 0;
484-
virtual Literal* negated() = 0;
484+
virtual Literal* negated() const = 0;
485485
};
486486

487487
// A class to hold a variable index.
488488
class BaseIntVar : public IntLinExpr, public Literal {
489489
public:
490-
explicit BaseIntVar(int index)
491-
: index_(index), is_boolean_(false), negated_(nullptr) {
492-
DCHECK_GE(index, 0);
493-
}
494-
BaseIntVar(int index, bool is_boolean)
495-
: index_(index), is_boolean_(is_boolean), negated_(nullptr) {
490+
explicit BaseIntVar(int index) : index_(index), negated_(nullptr) {
496491
DCHECK_GE(index, 0);
497492
}
493+
BaseIntVar(int index, bool is_boolean);
498494

499-
~BaseIntVar() override = default;
495+
~BaseIntVar() override {
496+
if (negated_ != nullptr) delete negated_;
497+
}
500498

501499
int index() const override { return index_; }
502500

@@ -509,7 +507,7 @@ class BaseIntVar : public IntLinExpr, public Literal {
509507
}
510508

511509
std::string ToString() const override {
512-
if (is_boolean_) {
510+
if (negated_ != nullptr) {
513511
return absl::StrCat("BooleanBaseIntVar(", index_, ")");
514512
} else {
515513
return absl::StrCat("BaseIntVar(", index_, ")");
@@ -518,23 +516,27 @@ class BaseIntVar : public IntLinExpr, public Literal {
518516

519517
std::string DebugString() const override {
520518
return absl::StrCat("BaseIntVar(index=", index_,
521-
", is_boolean=", is_boolean_, ")");
519+
", is_boolean=", negated_ != nullptr, ")");
522520
}
523521

524-
Literal* negated() override;
522+
Literal* negated() const override { return negated_; }
525523

526-
bool is_boolean() const { return is_boolean_; }
524+
bool is_boolean() const { return negated_ != nullptr; }
527525

528526
bool operator<(const BaseIntVar& other) const {
529527
return index_ < other.index_;
530528
}
531529

532530
protected:
533531
const int index_;
534-
bool is_boolean_;
535-
Literal* negated_;
532+
Literal* const negated_;
536533
};
537534

535+
template <typename H>
536+
H AbslHashValue(H h, const BaseIntVar* i) {
537+
return H::combine(std::move(h), i->index());
538+
}
539+
538540
// A class to hold a negated variable index.
539541
class NotBooleanVariable : public IntLinExpr, public Literal {
540542
public:
@@ -557,7 +559,7 @@ class NotBooleanVariable : public IntLinExpr, public Literal {
557559
return absl::StrCat("not(", var_->ToString(), ")");
558560
}
559561

560-
Literal* negated() override { return var_; }
562+
Literal* negated() const override { return var_; }
561563

562564
std::string DebugString() const override {
563565
return absl::StrCat("NotBooleanVariable(index=", var_->index(), ")");
@@ -567,13 +569,6 @@ class NotBooleanVariable : public IntLinExpr, public Literal {
567569
BaseIntVar* var_;
568570
};
569571

570-
inline Literal* BaseIntVar::negated() {
571-
if (negated_ == nullptr) {
572-
negated_ = new NotBooleanVariable(this);
573-
}
574-
return negated_;
575-
}
576-
577572
} // namespace python
578573
} // namespace sat
579574
} // namespace operations_research

ortools/sat/python/swig_helper.cc

Lines changed: 70 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -340,29 +340,29 @@ PYBIND11_MODULE(swig_helper, m) {
340340
.def("__str__", &FloatLinearExpr::ToString)
341341
.def("__repr__", &FloatLinearExpr::DebugString)
342342
.def("is_integer", &FloatLinearExpr::is_integer)
343-
.def("__add__", &FloatLinearExpr::FloatAddCst,
343+
.def("__add__", &FloatLinearExpr::FloatAddCst, arg("cst"),
344344
py::return_value_policy::automatic, py::keep_alive<0, 1>())
345-
.def("__add__", &FloatLinearExpr::FloatAdd,
345+
.def("__add__", &FloatLinearExpr::FloatAdd, arg("other").none(false),
346346
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
347347
py::keep_alive<0, 2>())
348-
.def("__radd__", &FloatLinearExpr::FloatAddCst,
348+
.def("__radd__", &FloatLinearExpr::FloatAddCst, arg("cst"),
349349
py::return_value_policy::automatic, py::keep_alive<0, 1>())
350-
.def("__radd__", &FloatLinearExpr::FloatAdd,
350+
.def("__radd__", &FloatLinearExpr::FloatAdd, arg("other").none(false),
351351
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
352352
py::keep_alive<0, 2>())
353-
.def("__sub__", &FloatLinearExpr::FloatSub,
353+
.def("__sub__", &FloatLinearExpr::FloatSub, arg("other").none(false),
354354
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
355355
py::keep_alive<0, 2>())
356-
.def("__sub__", &FloatLinearExpr::FloatSubCst,
356+
.def("__sub__", &FloatLinearExpr::FloatSubCst, arg("cst"),
357357
py::return_value_policy::automatic, py::keep_alive<0, 1>())
358-
.def("__rsub__", &FloatLinearExpr::FloatRSub,
358+
.def("__rsub__", &FloatLinearExpr::FloatRSub, arg("other").none(false),
359359
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
360360
py::keep_alive<0, 2>())
361-
.def("__rsub__", &FloatLinearExpr::FloatRSubCst,
361+
.def("__rsub__", &FloatLinearExpr::FloatRSubCst, arg("cst"),
362362
py::return_value_policy::automatic, py::keep_alive<0, 1>())
363-
.def("__mul__", &FloatLinearExpr::FloatMulCst,
363+
.def("__mul__", &FloatLinearExpr::FloatMulCst, arg("cst"),
364364
py::return_value_policy::automatic, py::keep_alive<0, 1>())
365-
.def("__rmul__", &FloatLinearExpr::FloatMulCst,
365+
.def("__rmul__", &FloatLinearExpr::FloatMulCst, arg("cst"),
366366
py::return_value_policy::automatic, py::keep_alive<0, 1>())
367367
.def("__neg__", &FloatLinearExpr::FloatNeg,
368368
py::return_value_policy::automatic, py::keep_alive<0, 1>());
@@ -428,60 +428,65 @@ PYBIND11_MODULE(swig_helper, m) {
428428
"Returns a constant linear expression.",
429429
py::return_value_policy::automatic)
430430
.def("is_integer", &IntLinExpr::is_integer)
431-
.def("__add__", &IntLinExpr::IntAddCst,
431+
.def("__add__", &IntLinExpr::IntAddCst, arg("cst"),
432432
py::return_value_policy::automatic, py::keep_alive<0, 1>())
433-
.def("__add__", &FloatLinearExpr::FloatAddCst,
433+
.def("__add__", &FloatLinearExpr::FloatAddCst, arg("cst"),
434434
py::return_value_policy::automatic, py::keep_alive<0, 1>())
435-
.def("__add__", &IntLinExpr::IntAdd, py::return_value_policy::automatic,
436-
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
437-
.def("__add__", &FloatLinearExpr::FloatAdd,
435+
.def("__add__", &IntLinExpr::IntAdd, arg("other").none(false),
436+
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
437+
py::keep_alive<0, 2>())
438+
.def("__add__", &FloatLinearExpr::FloatAdd, arg("other").none(false),
438439
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
439440
py::keep_alive<0, 2>())
440-
.def("__radd__", &IntLinExpr::IntAddCst,
441+
.def("__radd__", &IntLinExpr::IntAddCst, arg("cst"),
441442
py::return_value_policy::automatic, py::keep_alive<0, 1>())
442443
.def("__radd__", &FloatLinearExpr::FloatAddCst,
443444
py::return_value_policy::automatic, py::keep_alive<0, 1>())
444-
.def("__radd__", &FloatLinearExpr::FloatAdd,
445+
.def("__radd__", &FloatLinearExpr::FloatAdd, arg("other").none(false),
445446
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
446447
py::keep_alive<0, 2>())
447-
.def("__sub__", &IntLinExpr::IntSubCst,
448+
.def("__sub__", &IntLinExpr::IntSubCst, arg("cst"),
448449
py::return_value_policy::automatic, py::keep_alive<0, 1>())
449-
.def("__sub__", &FloatLinearExpr::FloatSubCst,
450+
.def("__sub__", &FloatLinearExpr::FloatSubCst, arg("cst"),
450451
py::return_value_policy::automatic, py::keep_alive<0, 1>())
451-
.def("__sub__", &FloatLinearExpr::FloatSubCst,
452+
.def("__sub__", &FloatLinearExpr::FloatSubCst, arg("cst"),
452453
py::return_value_policy::automatic, py::keep_alive<0, 1>())
453-
.def("__sub__", &IntLinExpr::IntSub, py::return_value_policy::automatic,
454-
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
455-
.def("__sub__", &FloatLinearExpr::FloatSub,
454+
.def("__sub__", &IntLinExpr::IntSub, arg("other").none(false),
455+
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
456+
py::keep_alive<0, 2>())
457+
.def("__sub__", &FloatLinearExpr::FloatSub, arg("other").none(false),
456458
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
457459
py::keep_alive<0, 2>())
458-
.def("__rsub__", &IntLinExpr::IntRSubCst,
460+
.def("__rsub__", &IntLinExpr::IntRSubCst, arg("cst"),
459461
py::return_value_policy::automatic, py::keep_alive<0, 1>())
460-
.def("__rsub__", &FloatLinearExpr::FloatRSubCst,
462+
.def("__rsub__", &FloatLinearExpr::FloatRSubCst, arg("cst"),
461463
py::return_value_policy::automatic, py::keep_alive<0, 1>())
462-
.def("__rsub__", &FloatLinearExpr::FloatRSub,
464+
.def("__rsub__", &FloatLinearExpr::FloatRSub, arg("cst"),
463465
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
464466
py::keep_alive<0, 2>())
465-
.def("__mul__", &IntLinExpr::IntMulCst,
467+
.def("__mul__", &IntLinExpr::IntMulCst, arg("cst"),
466468
py::return_value_policy::automatic, py::keep_alive<0, 1>())
467-
.def("__rmul__", &IntLinExpr::IntMulCst,
469+
.def("__rmul__", &IntLinExpr::IntMulCst, arg("cst"),
468470
py::return_value_policy::automatic, py::keep_alive<0, 1>())
469471
.def("__neg__", &IntLinExpr::IntNeg, py::return_value_policy::automatic,
470472
py::keep_alive<0, 1>())
471-
.def("__mul__", &FloatLinearExpr::FloatMulCst,
473+
.def("__mul__", &FloatLinearExpr::FloatMulCst, arg("cst"),
472474
py::return_value_policy::automatic, py::keep_alive<0, 1>())
473-
.def("__rmul__", &FloatLinearExpr::FloatMulCst,
475+
.def("__rmul__", &FloatLinearExpr::FloatMulCst, arg("cst"),
474476
py::return_value_policy::automatic, py::keep_alive<0, 1>())
475-
.def("__eq__", &IntLinExpr::Eq, py::return_value_policy::automatic,
476-
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
477+
.def("__eq__", &IntLinExpr::Eq, arg("other").none(false),
478+
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
479+
py::keep_alive<0, 2>())
477480
.def("__eq__", &IntLinExpr::EqCst, py::return_value_policy::automatic,
478481
py::keep_alive<0, 1>())
479-
.def("__ne__", &IntLinExpr::Ne, py::return_value_policy::automatic,
480-
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
482+
.def("__ne__", &IntLinExpr::Ne, arg("other").none(false),
483+
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
484+
py::keep_alive<0, 2>())
481485
.def("__ne__", &IntLinExpr::NeCst, py::return_value_policy::automatic,
482486
py::keep_alive<0, 1>())
483-
.def("__lt__", &IntLinExpr::Lt, py::return_value_policy::automatic,
484-
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
487+
.def("__lt__", &IntLinExpr::Lt, arg("other").none(false),
488+
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
489+
py::keep_alive<0, 2>())
485490
.def(
486491
"__lt__",
487492
[](IntLinExpr* expr, int64_t bound) {
@@ -491,8 +496,9 @@ PYBIND11_MODULE(swig_helper, m) {
491496
return expr->LtCst(bound);
492497
},
493498
py::return_value_policy::automatic, py::keep_alive<0, 1>())
494-
.def("__le__", &IntLinExpr::Le, py::return_value_policy::automatic,
495-
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
499+
.def("__le__", &IntLinExpr::Le, arg("other").none(false),
500+
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
501+
py::keep_alive<0, 2>())
496502
.def(
497503
"__le__",
498504
[](IntLinExpr* expr, int64_t bound) {
@@ -504,8 +510,9 @@ PYBIND11_MODULE(swig_helper, m) {
504510
py::return_value_policy::automatic,
505511

506512
py::keep_alive<0, 1>())
507-
.def("__gt__", &IntLinExpr::Gt, py::return_value_policy::automatic,
508-
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
513+
.def("__gt__", &IntLinExpr::Gt, arg("other").none(false),
514+
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
515+
py::keep_alive<0, 2>())
509516
.def(
510517
"__gt__",
511518
[](IntLinExpr* expr, int64_t bound) {
@@ -515,8 +522,9 @@ PYBIND11_MODULE(swig_helper, m) {
515522
return expr->GtCst(bound);
516523
},
517524
py::return_value_policy::automatic, py::keep_alive<0, 1>())
518-
.def("__ge__", &IntLinExpr::Ge, py::return_value_policy::automatic,
519-
py::keep_alive<0, 1>(), py::keep_alive<0, 2>())
525+
.def("__ge__", &IntLinExpr::Ge, arg("other").none(false),
526+
py::return_value_policy::automatic, py::keep_alive<0, 1>(),
527+
py::keep_alive<0, 2>())
520528
.def(
521529
"__ge__",
522530
[](IntLinExpr* expr, int64_t bound) {
@@ -684,21 +692,27 @@ PYBIND11_MODULE(swig_helper, m) {
684692
It is only valid if the variable has a Boolean domain (0 or 1).
685693
686694
Note that this method is nilpotent: `x.negated().negated() == x`.
687-
)doc",
688-
py::return_value_policy::automatic, py::keep_alive<1, 0>())
695+
)doc")
689696
.def("__invert__", &Literal::negated,
690-
"Returns the negation of the current literal.",
691-
py::return_value_policy::automatic)
697+
"Returns the negation of the current literal.")
692698
.def("__bool__",
693699
[](Literal* /*self*/) {
694700
throw_error(PyExc_NotImplementedError,
695701
"Evaluating a Literal instance as a Boolean is "
696702
"not implemented.");
697703
})
698704
// PEP8 Compatibility.
699-
.def("Not", &Literal::negated, py::return_value_policy::automatic)
705+
.def("Not", &Literal::negated)
700706
.def("Index", &Literal::index);
701707

708+
// Memory management:
709+
// - The BaseIntVar owns the NotBooleanVariable.
710+
// - The NotBooleanVariable is created at the same time as the base variable
711+
// when the variable is boolean.
712+
// - The negated() methods return an internal reference to the negated
713+
// object. That means memory of the negated variable is onwed by the C++
714+
// layer, but a reference is kept in python to link the lifetime of the
715+
// negated variable to the base variable.
702716
py::class_<BaseIntVar, PyBaseIntVar, IntLinExpr, Literal>(m, "BaseIntVar")
703717
.def(py::init<int>())
704718
.def(py::init<int, bool>())
@@ -718,7 +732,7 @@ PYBIND11_MODULE(swig_helper, m) {
718732
return self->negated();
719733
},
720734
"Returns the negation of the current Boolean variable.",
721-
py::return_value_policy::automatic, py::keep_alive<1, 0>())
735+
py::return_value_policy::reference_internal)
722736
.def(
723737
"__invert__",
724738
[](BaseIntVar* self) {
@@ -729,7 +743,7 @@ PYBIND11_MODULE(swig_helper, m) {
729743
return self->negated();
730744
},
731745
"Returns the negation of the current Boolean variable.",
732-
py::return_value_policy::automatic, py::keep_alive<1, 0>())
746+
py::return_value_policy::reference_internal)
733747
// PEP8 Compatibility.
734748
.def(
735749
"Not",
@@ -740,8 +754,11 @@ PYBIND11_MODULE(swig_helper, m) {
740754
}
741755
return self->negated();
742756
},
743-
py::return_value_policy::automatic, py::keep_alive<1, 0>());
757+
py::return_value_policy::reference_internal);
744758

759+
// Memory management:
760+
// - Do we need a reference_internal (that add a py::keep_alive<1, 0>() rule)
761+
// or just a reference ?
745762
py::class_<NotBooleanVariable, IntLinExpr, Literal>(m, "NotBooleanVariable")
746763
.def(py::init<BaseIntVar*>())
747764
.def_property_readonly("index", &NotBooleanVariable::index,
@@ -750,13 +767,13 @@ PYBIND11_MODULE(swig_helper, m) {
750767
.def("__repr__", &NotBooleanVariable::DebugString)
751768
.def("negated", &NotBooleanVariable::negated,
752769
"Returns the negation of the current Boolean variable.",
753-
py::return_value_policy::automatic)
770+
py::return_value_policy::reference_internal)
754771
.def("__invert__", &NotBooleanVariable::negated,
755772
"Returns the negation of the current Boolean variable.",
756-
py::return_value_policy::automatic)
773+
py::return_value_policy::reference_internal)
757774
.def("Not", &NotBooleanVariable::negated,
758775
"Returns the negation of the current Boolean variable.",
759-
py::return_value_policy::automatic);
776+
py::return_value_policy::reference_internal);
760777

761778
py::class_<BoundedLinearExpression>(m, "BoundedLinearExpression")
762779
.def(py::init<IntLinExpr*, const Domain&>())

0 commit comments

Comments
 (0)