Skip to content

Commit 53d2736

Browse files
committed
[CP-SAT] fix python layer
1 parent b53fe42 commit 53d2736

File tree

3 files changed

+49
-32
lines changed

3 files changed

+49
-32
lines changed

ortools/sat/python/cp_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,8 +312,9 @@ def rebuild_from_linear_expression_proto(
312312
if num_elements == 0:
313313
return proto.offset
314314
elif num_elements == 1:
315-
return (
316-
IntVar(model, proto.vars[0], False, None) * proto.coeffs[0] + proto.offset
315+
var = IntVar(model, proto.vars[0], False, None)
316+
return LinearExpr.affine(
317+
var, proto.coeffs[0], proto.offset
317318
) # pytype: disable=bad-return-type
318319
else:
319320
variables = []

ortools/sat/python/linear_expr.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,11 +171,13 @@ LinearExpr* LinearExpr::Term(LinearExpr* expr, int64_t coeff) {
171171
}
172172

173173
LinearExpr* LinearExpr::Affine(LinearExpr* expr, double coeff, double offset) {
174+
if (coeff == 1.0 && offset == 0.0) return expr;
174175
return new FloatAffine(expr, coeff, offset);
175176
}
176177

177178
LinearExpr* LinearExpr::Affine(LinearExpr* expr, int64_t coeff,
178179
int64_t offset) {
180+
if (coeff == 1 && offset == 0) return expr;
179181
return new IntAffine(expr, coeff, offset);
180182
}
181183

@@ -192,10 +194,12 @@ LinearExpr* LinearExpr::Add(LinearExpr* other) {
192194
}
193195

194196
LinearExpr* LinearExpr::AddInt(int64_t cst) {
197+
if (cst == 0) return this;
195198
return new IntAffine(this, 1, cst);
196199
}
197200

198201
LinearExpr* LinearExpr::AddDouble(double cst) {
202+
if (cst == 0.0) return this;
199203
return new FloatAffine(this, 1.0, cst);
200204
}
201205

ortools/sat/python/swig_helper.cc

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ class PySolutionCallback : public SolutionCallback {
5656
}
5757
};
5858

59+
void throw_error(PyObject* py_exception, const std::string& message) {
60+
PyErr_SetString(py_exception, message.c_str());
61+
throw py::error_already_set();
62+
}
63+
5964
// A trampoline class to override the __str__ and __repr__ methods.
6065
class PyBaseIntVar : public BaseIntVar {
6166
public:
@@ -140,9 +145,9 @@ class ResponseWrapper {
140145
IntExprVisitor visitor;
141146
int64_t value;
142147
if (!visitor.Evaluate(expr, response_, &value)) {
143-
LOG(ERROR) << "Failed to evaluate linear expression: "
144-
<< expr->DebugString();
145-
return -1;
148+
throw_error(PyExc_TypeError,
149+
absl::StrCat("Failed to evaluate linear expression: ",
150+
expr->DebugString()));
146151
}
147152
return value;
148153
}
@@ -155,11 +160,6 @@ class ResponseWrapper {
155160
const CpSolverResponse response_;
156161
};
157162

158-
void throw_error(PyObject* py_exception, const std::string& message) {
159-
PyErr_SetString(py_exception, message.c_str());
160-
throw py::error_already_set();
161-
}
162-
163163
const char* kLinearExprClassDoc = R"doc(
164164
Holds an integer linear expression.
165165
@@ -318,6 +318,8 @@ PYBIND11_MODULE(swig_helper, m) {
318318
py::implicitly_convertible<int64_t, ExprOrValue>();
319319

320320
py::class_<LinearExpr>(m, "LinearExpr", kLinearExprClassDoc)
321+
// We make sure to keep the order of the overloads: LinearExpr* before
322+
// ExprOrValue as this is faster to parse and type check.
321323
.def_static(
322324
"sum",
323325
py::overload_cast<const std::vector<LinearExpr*>&>(&LinearExpr::Sum),
@@ -346,24 +348,8 @@ PYBIND11_MODULE(swig_helper, m) {
346348
const std::vector<double>&>(
347349
&LinearExpr::WeightedSum),
348350
py::return_value_policy::automatic, py::keep_alive<0, 1>())
349-
.def_static(
350-
"Sum",
351-
py::overload_cast<const std::vector<LinearExpr*>&>(&LinearExpr::Sum),
352-
py::return_value_policy::automatic, py::keep_alive<0, 1>())
353-
.def_static(
354-
"Sum",
355-
py::overload_cast<const std::vector<ExprOrValue>&>(&LinearExpr::Sum),
356-
py::return_value_policy::automatic, py::keep_alive<0, 1>())
357-
.def_static("WeightedSum",
358-
py::overload_cast<const std::vector<ExprOrValue>&,
359-
const std::vector<double>&>(
360-
&LinearExpr::WeightedSum),
361-
py::return_value_policy::automatic, py::keep_alive<0, 1>())
362-
.def_static("WeightedSum",
363-
py::overload_cast<const std::vector<ExprOrValue>&,
364-
const std::vector<int64_t>&>(
365-
&LinearExpr::WeightedSum),
366-
py::return_value_policy::automatic, py::keep_alive<0, 1>())
351+
// Make sure to keep the order of the overloads: int before float as an
352+
// an integer value will be silently converted to a float.
367353
.def_static("term",
368354
py::overload_cast<LinearExpr*, int64_t>(&LinearExpr::Term),
369355
arg("expr"), arg("coeff"), "Returns expr * coeff.",
@@ -374,22 +360,48 @@ PYBIND11_MODULE(swig_helper, m) {
374360
py::return_value_policy::automatic, py::keep_alive<0, 1>())
375361
.def_static(
376362
"affine",
377-
py::overload_cast<LinearExpr*, double, double>(&LinearExpr::Affine),
363+
py::overload_cast<LinearExpr*, int64_t, int64_t>(&LinearExpr::Affine),
378364
arg("expr"), arg("coeff"), arg("offset"),
379365
"Returns expr * coeff + offset.", py::return_value_policy::automatic,
380366
py::keep_alive<0, 1>())
381367
.def_static(
382368
"affine",
383-
py::overload_cast<LinearExpr*, int64_t, int64_t>(&LinearExpr::Affine),
369+
py::overload_cast<LinearExpr*, double, double>(&LinearExpr::Affine),
384370
arg("expr"), arg("coeff"), arg("offset"),
385371
"Returns expr * coeff + offset.", py::return_value_policy::automatic,
386372
py::keep_alive<0, 1>())
387-
.def_static("constant", py::overload_cast<double>(&LinearExpr::Constant),
373+
.def_static("constant", py::overload_cast<int64_t>(&LinearExpr::Constant),
388374
arg("value"), "Returns a constant linear expression.",
389375
py::return_value_policy::automatic)
390-
.def_static("constant", py::overload_cast<int64_t>(&LinearExpr::Constant),
376+
.def_static("constant", py::overload_cast<double>(&LinearExpr::Constant),
391377
arg("value"), "Returns a constant linear expression.",
392378
py::return_value_policy::automatic)
379+
.def_static(
380+
"Sum",
381+
py::overload_cast<const std::vector<LinearExpr*>&>(&LinearExpr::Sum),
382+
py::return_value_policy::automatic, py::keep_alive<0, 1>())
383+
.def_static(
384+
"Sum",
385+
py::overload_cast<const std::vector<ExprOrValue>&>(&LinearExpr::Sum),
386+
py::return_value_policy::automatic, py::keep_alive<0, 1>())
387+
.def_static("WeightedSum",
388+
py::overload_cast<const std::vector<ExprOrValue>&,
389+
const std::vector<int64_t>&>(
390+
&LinearExpr::WeightedSum),
391+
py::return_value_policy::automatic, py::keep_alive<0, 1>())
392+
.def_static("WeightedSum",
393+
py::overload_cast<const std::vector<ExprOrValue>&,
394+
const std::vector<double>&>(
395+
&LinearExpr::WeightedSum),
396+
py::return_value_policy::automatic, py::keep_alive<0, 1>())
397+
.def_static("Term",
398+
py::overload_cast<LinearExpr*, int64_t>(&LinearExpr::Term),
399+
arg("expr"), arg("coeff"), "Returns expr * coeff.",
400+
py::return_value_policy::automatic, py::keep_alive<0, 1>())
401+
.def_static("Term",
402+
py::overload_cast<LinearExpr*, double>(&LinearExpr::Term),
403+
arg("expr"), arg("coeff"), "Returns expr * coeff.",
404+
py::return_value_policy::automatic, py::keep_alive<0, 1>())
393405
.def("__str__", &LinearExpr::ToString)
394406
.def("__repr__", &LinearExpr::DebugString)
395407
.def("is_integer", &LinearExpr::IsInteger)

0 commit comments

Comments
 (0)