Skip to content

Commit ba26576

Browse files
committed
[FZ] more work on native set support
1 parent 30805fd commit ba26576

File tree

9 files changed

+2692
-1895
lines changed

9 files changed

+2692
-1895
lines changed

ortools/flatzinc/checker.cc

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <cstdint>
1818
#include <cstdlib>
1919
#include <functional>
20+
#include <iterator>
2021
#include <limits>
2122
#include <string>
2223
#include <utility>
@@ -99,6 +100,35 @@ std::vector<int64_t> SetEval(
99100
}
100101
}
101102

103+
std::vector<int64_t> SetEvalAt(
104+
const Argument& arg, int pos,
105+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
106+
switch (arg.type) {
107+
case Argument::DOMAIN_LIST: {
108+
const Domain& domain = arg.domains[pos];
109+
if (domain.empty()) {
110+
return {};
111+
} else if (domain.is_interval) {
112+
std::vector<int64_t> result;
113+
result.reserve(domain.Max() - domain.Min() + 1);
114+
for (int64_t i = domain.Min(); i <= domain.Max(); ++i) {
115+
result.push_back(i);
116+
}
117+
return result;
118+
} else {
119+
return domain.values;
120+
}
121+
}
122+
case Argument::VAR_REF_ARRAY: {
123+
return set_evaluator(arg.variables[pos]);
124+
}
125+
default: {
126+
LOG(FATAL) << "Cannot evaluate " << arg.DebugString();
127+
return {};
128+
}
129+
}
130+
}
131+
102132
int64_t SetSize(
103133
const Argument& arg,
104134
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
@@ -1210,6 +1240,17 @@ bool CheckSetCard(
12101240
return set_size == cardinality;
12111241
}
12121242

1243+
bool CheckArraySetElement(
1244+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1245+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1246+
const int64_t index = Eval(ct.arguments[0], evaluator);
1247+
const int64_t min_index = ct.arguments[0].Var()->domain.Min();
1248+
const std::vector<int64_t> element =
1249+
SetEvalAt(ct.arguments[1], index - min_index, set_evaluator);
1250+
const std::vector<int64_t> target = SetEval(ct.arguments[2], set_evaluator);
1251+
return element == target;
1252+
}
1253+
12131254
bool CheckSetIn(
12141255
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
12151256
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
@@ -1236,6 +1277,142 @@ bool CheckSetInReif(
12361277
return contain == (status == 1);
12371278
}
12381279

1280+
bool CheckSetIntersect(
1281+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1282+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1283+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1284+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1285+
const std::vector<int64_t> values_r = SetEval(ct.arguments[2], set_evaluator);
1286+
absl::flat_hash_set<int64_t> set_x(values_x.begin(), values_x.end());
1287+
absl::flat_hash_set<int64_t> set_y(values_y.begin(), values_y.end());
1288+
absl::flat_hash_set<int64_t> set_r(values_r.begin(), values_r.end());
1289+
absl::flat_hash_set<int64_t> computed_intersection;
1290+
std::set_intersection(
1291+
values_x.begin(), values_x.end(), values_y.begin(), values_y.end(),
1292+
std::inserter(computed_intersection, computed_intersection.begin()));
1293+
return computed_intersection == set_r;
1294+
}
1295+
1296+
bool CheckSetUnion(
1297+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1298+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1299+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1300+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1301+
const std::vector<int64_t> values_r = SetEval(ct.arguments[2], set_evaluator);
1302+
absl::flat_hash_set<int64_t> set_x(values_x.begin(), values_x.end());
1303+
absl::flat_hash_set<int64_t> set_y(values_y.begin(), values_y.end());
1304+
absl::flat_hash_set<int64_t> set_r(values_r.begin(), values_r.end());
1305+
absl::flat_hash_set<int64_t> computed_intersection;
1306+
std::set_union(
1307+
values_x.begin(), values_x.end(), values_y.begin(), values_y.end(),
1308+
std::inserter(computed_intersection, computed_intersection.begin()));
1309+
return computed_intersection == set_r;
1310+
}
1311+
1312+
bool CheckSetSubset(
1313+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1314+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1315+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1316+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1317+
return std::includes(values_y.begin(), values_y.end(), values_x.begin(),
1318+
values_x.end());
1319+
}
1320+
1321+
bool CheckSetSubsetReif(
1322+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1323+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1324+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1325+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1326+
const bool status = Eval(ct.arguments[2], evaluator) != 0;
1327+
return std::includes(values_y.begin(), values_y.end(), values_x.begin(),
1328+
values_x.end()) == status;
1329+
}
1330+
1331+
bool CheckSetSuperset(
1332+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1333+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1334+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1335+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1336+
return std::includes(values_x.begin(), values_x.end(), values_y.begin(),
1337+
values_y.end());
1338+
}
1339+
1340+
bool CheckSetSupersetReif(
1341+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1342+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1343+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1344+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1345+
const bool status = Eval(ct.arguments[2], evaluator) != 0;
1346+
return std::includes(values_x.begin(), values_x.end(), values_y.begin(),
1347+
values_y.end()) == status;
1348+
}
1349+
1350+
bool CheckSetDiff(
1351+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1352+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1353+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1354+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1355+
const std::vector<int64_t> values_r = SetEval(ct.arguments[2], set_evaluator);
1356+
absl::flat_hash_set<int64_t> set_x(values_x.begin(), values_x.end());
1357+
absl::flat_hash_set<int64_t> set_y(values_y.begin(), values_y.end());
1358+
absl::flat_hash_set<int64_t> set_r(values_r.begin(), values_r.end());
1359+
absl::flat_hash_set<int64_t> computed_diff;
1360+
std::set_difference(values_x.begin(), values_x.end(), values_y.begin(),
1361+
values_y.end(),
1362+
std::inserter(computed_diff, computed_diff.begin()));
1363+
return computed_diff == set_r;
1364+
}
1365+
1366+
bool CheckSetSymDiff(
1367+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1368+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1369+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1370+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1371+
const std::vector<int64_t> values_r = SetEval(ct.arguments[2], set_evaluator);
1372+
absl::flat_hash_set<int64_t> set_x(values_x.begin(), values_x.end());
1373+
absl::flat_hash_set<int64_t> set_y(values_y.begin(), values_y.end());
1374+
absl::flat_hash_set<int64_t> set_r(values_r.begin(), values_r.end());
1375+
absl::flat_hash_set<int64_t> computed_sym_diff;
1376+
std::set_symmetric_difference(
1377+
values_x.begin(), values_x.end(), values_y.begin(), values_y.end(),
1378+
std::inserter(computed_sym_diff, computed_sym_diff.begin()));
1379+
return computed_sym_diff == set_r;
1380+
}
1381+
1382+
bool CheckSetEq(
1383+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1384+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1385+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1386+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1387+
return values_x == values_y;
1388+
}
1389+
1390+
bool CheckSetEqReif(
1391+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1392+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1393+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1394+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1395+
const bool status = Eval(ct.arguments[2], evaluator) != 0;
1396+
return (values_x == values_y) == status;
1397+
}
1398+
1399+
bool CheckSetNe(
1400+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1401+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1402+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1403+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1404+
return values_x != values_y;
1405+
}
1406+
1407+
bool CheckSetNeReif(
1408+
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
1409+
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
1410+
const std::vector<int64_t> values_x = SetEval(ct.arguments[0], set_evaluator);
1411+
const std::vector<int64_t> values_y = SetEval(ct.arguments[1], set_evaluator);
1412+
const bool status = Eval(ct.arguments[2], evaluator) != 0;
1413+
return (values_x != values_y) == status;
1414+
}
1415+
12391416
bool CheckSlidingSum(
12401417
const Constraint& ct, const std::function<int64_t(Variable*)>& evaluator,
12411418
const std::function<std::vector<int64_t>(Variable*)>& set_evaluator) {
@@ -1361,8 +1538,10 @@ CallMap CreateCallMap() {
13611538
m["array_int_element_nonshifted"] = CheckArrayIntElementNonShifted;
13621539
m["array_int_maximum"] = CheckMaximumInt;
13631540
m["array_int_minimum"] = CheckMinimumInt;
1541+
m["array_set_element"] = CheckArraySetElement;
13641542
m["array_var_bool_element"] = CheckArrayVarIntElement;
13651543
m["array_var_int_element"] = CheckArrayVarIntElement;
1544+
m["array_var_set_element"] = CheckArraySetElement;
13661545
m["at_most_int"] = CheckAtMostInt;
13671546
m["bool_and"] = CheckBoolAnd;
13681547
m["bool_clause"] = CheckBoolClause;
@@ -1483,9 +1662,21 @@ CallMap CreateCallMap() {
14831662
m["ortools_table_int"] = CheckTableInt;
14841663
m["regular_nfa"] = CheckRegularNfa;
14851664
m["set_card"] = CheckSetCard;
1665+
m["set_diff"] = CheckSetDiff;
1666+
m["set_eq_reif"] = CheckSetEqReif;
1667+
m["set_eq"] = CheckSetEq;
14861668
m["set_in_reif"] = CheckSetInReif;
14871669
m["set_in"] = CheckSetIn;
1670+
m["set_intersect"] = CheckSetIntersect;
1671+
m["set_ne_reif"] = CheckSetNeReif;
1672+
m["set_ne"] = CheckSetNe;
14881673
m["set_not_in"] = CheckSetNotIn;
1674+
m["set_subset_reif"] = CheckSetSubsetReif;
1675+
m["set_subset"] = CheckSetSubset;
1676+
m["set_superset_reif"] = CheckSetSupersetReif;
1677+
m["set_superset"] = CheckSetSuperset;
1678+
m["set_symdiff"] = CheckSetSymDiff;
1679+
m["set_union"] = CheckSetUnion;
14891680
m["sliding_sum"] = CheckSlidingSum;
14901681
m["sort"] = CheckSort;
14911682
m["symmetric_all_different"] = CheckSymmetricAllDifferent;

0 commit comments

Comments
 (0)