Skip to content

Commit 492fca3

Browse files
committed
[CP-SAT] polish python API and add more tests
1 parent 5dbcdbc commit 492fca3

File tree

2 files changed

+97
-9
lines changed

2 files changed

+97
-9
lines changed

ortools/sat/python/cp_model_helper.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,17 @@ void ProcessExprArg(const py::handle& arg, LinearExpr*& expr,
205205
float_value = arg.cast<double>();
206206
}
207207
} else {
208-
ThrowError(PyExc_TypeError,
209-
absl::StrCat("LinearExpr::sum() only accept linear "
210-
"expressions and constants as argument: ",
211-
arg.cast<std::string>()));
208+
try {
209+
expr = arg.cast<LinearExpr*>();
210+
ThrowError(PyExc_TypeError,
211+
absl::StrCat("LinearExpr::sum() only accept linear "
212+
"expressions and constants as argument: ",
213+
arg.cast<std::string>()));
214+
} catch (py::cast_error& e) {
215+
ThrowError(PyExc_TypeError,
216+
"LinearExpr::sum() only accept linear expressions and "
217+
"constants as argument.");
218+
}
212219
}
213220
}
214221

@@ -255,10 +262,7 @@ LinearExpr* SumArguments(py::args expressions) {
255262
}
256263
};
257264

258-
if (expressions.size() == 0) {
259-
return new IntConstant(0);
260-
} else if (expressions.size() == 1 &&
261-
py::isinstance<py::sequence>(expressions[0])) {
265+
if (expressions.size() == 1 && py::isinstance<py::sequence>(expressions[0])) {
262266
// Normal list or tuple argument.
263267
py::sequence elements = expressions[0].cast<py::sequence>();
264268
linear_exprs.reserve(elements.size());
@@ -339,7 +343,7 @@ LinearExpr* WeightedSumArguments(py::sequence expressions,
339343
int_offset += int_mult * int_value;
340344
float_offset += (float_mult + static_cast<double>(int_mult)) *
341345
static_cast<double>(int_value);
342-
} else if (float_value != 0.0) {
346+
} else {
343347
float_offset +=
344348
(float_mult + static_cast<double>(int_mult)) * float_value;
345349
}

ortools/sat/python/cp_model_test.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,93 @@ def testSumParsing(self) -> None:
561561
self.assertLen(flat_s7.vars, 2)
562562
self.assertEqual(3.5, flat_s7.offset)
563563

564+
s8 = cp_model.LinearExpr.sum(x[0], 3)
565+
self.assertTrue(s8.is_integer())
566+
self.assertIsInstance(s8, cmh.IntAffine)
567+
self.assertEqual(s8.expression, x[0])
568+
self.assertEqual(s8.coefficient, 1)
569+
self.assertEqual(s8.offset, 1)
570+
571+
s9 = cp_model.LinearExpr.sum(x[0], -2.1)
572+
self.assertFalse(s9.is_integer())
573+
self.assertIsInstance(s9, cmh.FloatAffine)
574+
self.assertEqual(s9.expression, x[0])
575+
self.assertEqual(s9.coefficient, 1.0)
576+
self.assertEqual(s9.offset, -2.1)
577+
578+
s10 = cp_model.LinearExpr.sum(x[0], 1, -1)
579+
self.assertTrue(s10.is_integer())
580+
self.assertIsInstance(s10, cp_model.IntVar)
581+
self.assertEqual(s10, x[0])
582+
583+
s11 = cp_model.LinearExpr.sum(x[0])
584+
self.assertTrue(s11.is_integer())
585+
self.assertIsInstance(s11, cp_model.IntVar)
586+
self.assertEqual(s11, x[0])
587+
588+
class FakeNpDTypeA:
589+
590+
def __init__(self):
591+
self.dtype = 2
592+
pass
593+
594+
def __str__(self):
595+
return "FakeNpDTypeA"
596+
597+
class FakeNpDTypeB:
598+
599+
def __init__(self):
600+
self.is_integer = False
601+
pass
602+
603+
def __str__(self):
604+
return "FakeNpDTypeB"
605+
564606
with self.assertRaises(TypeError):
565607
cp_model.LinearExpr.sum(x[0], x[2], "foo")
566608

609+
with self.assertRaises(TypeError):
610+
cp_model.LinearExpr.sum(x[0], x[2], FakeNpDTypeA())
611+
612+
with self.assertRaises(TypeError):
613+
cp_model.LinearExpr.sum(x[0], x[2], FakeNpDTypeB())
614+
615+
def testWeightedSumParsing(self) -> None:
616+
model = cp_model.CpModel()
617+
x = [model.new_int_var(0, 2, "x%i" % i) for i in range(5)]
618+
c = [1, -2, 2, 3, 0.0]
619+
float_c = [1, -1.0, 2, 3, 0.0]
620+
621+
s1 = cp_model.LinearExpr.weighted_sum(x, c)
622+
self.assertTrue(s1.is_integer())
623+
flat_s1 = cp_model.FlatIntExpr(s1)
624+
self.assertLen(flat_s1.vars, 4)
625+
self.assertEqual(0, flat_s1.offset)
626+
627+
s2 = cp_model.LinearExpr.weighted_sum(x, float_c)
628+
self.assertFalse(s2.is_integer())
629+
flat_s2 = cp_model.FlatFloatExpr(s2)
630+
self.assertLen(flat_s2.vars, 4)
631+
self.assertEqual(0, flat_s2.offset)
632+
633+
s3 = cp_model.LinearExpr.weighted_sum(x + [2], c + [-1])
634+
self.assertTrue(s3.is_integer())
635+
flat_s3 = cp_model.FlatIntExpr(s3)
636+
self.assertLen(flat_s3.vars, 4)
637+
self.assertEqual(-2, flat_s3.offset)
638+
639+
s4 = cp_model.LinearExpr.weighted_sum(x + [2], float_c + [-1.0])
640+
self.assertFalse(s4.is_integer())
641+
flat_s4 = cp_model.FlatFloatExpr(s4)
642+
self.assertLen(flat_s4.vars, 4)
643+
self.assertEqual(-2, flat_s4.offset)
644+
645+
s5 = cp_model.LinearExpr.weighted_sum(x + [np.int16(2)], c + [-1])
646+
self.assertTrue(s5.is_integer())
647+
flat_s5 = cp_model.FlatIntExpr(s5)
648+
self.assertLen(flat_s5.vars, 4)
649+
self.assertEqual(-2, flat_s5.offset)
650+
567651
def testSumWithApi(self) -> None:
568652
model = cp_model.CpModel()
569653
x = [model.new_int_var(0, 2, "x%i" % i) for i in range(100)]

0 commit comments

Comments
 (0)