|
15 | 15 | #include "mlir-c/Bindings/Python/Interop.h"
|
16 | 16 | #include "mlir-c/BuiltinAttributes.h"
|
17 | 17 | #include "mlir-c/BuiltinTypes.h"
|
| 18 | +#include "mlir-c/IntegerSet.h" |
18 | 19 | #include "mlir-c/Registration.h"
|
19 | 20 | #include "llvm/ADT/SmallVector.h"
|
20 | 21 | #include <pybind11/stl.h>
|
@@ -3331,6 +3332,102 @@ PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
|
3331 | 3332 | rawAffineMap);
|
3332 | 3333 | }
|
3333 | 3334 |
|
| 3335 | +//------------------------------------------------------------------------------ |
| 3336 | +// PyIntegerSet and utilities. |
| 3337 | +//------------------------------------------------------------------------------ |
| 3338 | + |
| 3339 | +class PyIntegerSetConstraint { |
| 3340 | +public: |
| 3341 | + PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} |
| 3342 | + |
| 3343 | + PyAffineExpr getExpr() { |
| 3344 | + return PyAffineExpr(set.getContext(), |
| 3345 | + mlirIntegerSetGetConstraint(set, pos)); |
| 3346 | + } |
| 3347 | + |
| 3348 | + bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } |
| 3349 | + |
| 3350 | + static void bind(py::module &m) { |
| 3351 | + py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint") |
| 3352 | + .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) |
| 3353 | + .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); |
| 3354 | + } |
| 3355 | + |
| 3356 | +private: |
| 3357 | + PyIntegerSet set; |
| 3358 | + intptr_t pos; |
| 3359 | +}; |
| 3360 | + |
| 3361 | +class PyIntegerSetConstraintList |
| 3362 | + : public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> { |
| 3363 | +public: |
| 3364 | + static constexpr const char *pyClassName = "IntegerSetConstraintList"; |
| 3365 | + |
| 3366 | + PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, |
| 3367 | + intptr_t length = -1, intptr_t step = 1) |
| 3368 | + : Sliceable(startIndex, |
| 3369 | + length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, |
| 3370 | + step), |
| 3371 | + set(set) {} |
| 3372 | + |
| 3373 | + intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } |
| 3374 | + |
| 3375 | + PyIntegerSetConstraint getElement(intptr_t pos) { |
| 3376 | + return PyIntegerSetConstraint(set, pos); |
| 3377 | + } |
| 3378 | + |
| 3379 | + PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, |
| 3380 | + intptr_t step) { |
| 3381 | + return PyIntegerSetConstraintList(set, startIndex, length, step); |
| 3382 | + } |
| 3383 | + |
| 3384 | +private: |
| 3385 | + PyIntegerSet set; |
| 3386 | +}; |
| 3387 | + |
| 3388 | +bool PyIntegerSet::operator==(const PyIntegerSet &other) { |
| 3389 | + return mlirIntegerSetEqual(integerSet, other.integerSet); |
| 3390 | +} |
| 3391 | + |
| 3392 | +py::object PyIntegerSet::getCapsule() { |
| 3393 | + return py::reinterpret_steal<py::object>( |
| 3394 | + mlirPythonIntegerSetToCapsule(*this)); |
| 3395 | +} |
| 3396 | + |
| 3397 | +PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { |
| 3398 | + MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); |
| 3399 | + if (mlirIntegerSetIsNull(rawIntegerSet)) |
| 3400 | + throw py::error_already_set(); |
| 3401 | + return PyIntegerSet( |
| 3402 | + PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), |
| 3403 | + rawIntegerSet); |
| 3404 | +} |
| 3405 | + |
| 3406 | +/// Attempts to populate `result` with the content of `list` casted to the |
| 3407 | +/// appropriate type (Python and C types are provided as template arguments). |
| 3408 | +/// Throws errors in case of failure, using "action" to describe what the caller |
| 3409 | +/// was attempting to do. |
| 3410 | +template <typename PyType, typename CType> |
| 3411 | +static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result, |
| 3412 | + StringRef action) { |
| 3413 | + result.reserve(py::len(list)); |
| 3414 | + for (py::handle item : list) { |
| 3415 | + try { |
| 3416 | + result.push_back(item.cast<PyType>()); |
| 3417 | + } catch (py::cast_error &err) { |
| 3418 | + std::string msg = (llvm::Twine("Invalid expression when ") + action + |
| 3419 | + " (" + err.what() + ")") |
| 3420 | + .str(); |
| 3421 | + throw py::cast_error(msg); |
| 3422 | + } catch (py::reference_cast_error &err) { |
| 3423 | + std::string msg = (llvm::Twine("Invalid expression (None?) when ") + |
| 3424 | + action + " (" + err.what() + ")") |
| 3425 | + .str(); |
| 3426 | + throw py::cast_error(msg); |
| 3427 | + } |
| 3428 | + } |
| 3429 | +} |
| 3430 | + |
3334 | 3431 | //------------------------------------------------------------------------------
|
3335 | 3432 | // Populates the pybind11 IR submodule.
|
3336 | 3433 | //------------------------------------------------------------------------------
|
@@ -4152,24 +4249,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
4152 | 4249 | [](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
|
4153 | 4250 | DefaultingPyMlirContext context) {
|
4154 | 4251 | SmallVector<MlirAffineExpr> affineExprs;
|
4155 |
| - affineExprs.reserve(py::len(exprs)); |
4156 |
| - for (py::handle expr : exprs) { |
4157 |
| - try { |
4158 |
| - affineExprs.push_back(expr.cast<PyAffineExpr>()); |
4159 |
| - } catch (py::cast_error &err) { |
4160 |
| - std::string msg = |
4161 |
| - std::string("Invalid expression when attempting to create " |
4162 |
| - "an AffineMap (") + |
4163 |
| - err.what() + ")"; |
4164 |
| - throw py::cast_error(msg); |
4165 |
| - } catch (py::reference_cast_error &err) { |
4166 |
| - std::string msg = |
4167 |
| - std::string("Invalid expression (None?) when attempting to " |
4168 |
| - "create an AffineMap (") + |
4169 |
| - err.what() + ")"; |
4170 |
| - throw py::cast_error(msg); |
4171 |
| - } |
4172 |
| - } |
| 4252 | + pyListToVector<PyAffineExpr, MlirAffineExpr>( |
| 4253 | + exprs, affineExprs, "attempting to create an AffineMap"); |
4173 | 4254 | MlirAffineMap map =
|
4174 | 4255 | mlirAffineMapGet(context->get(), dimCount, symbolCount,
|
4175 | 4256 | affineExprs.size(), affineExprs.data());
|
@@ -4275,4 +4356,125 @@ void mlir::python::populateIRSubmodule(py::module &m) {
|
4275 | 4356 | return PyAffineMapExprList(self);
|
4276 | 4357 | });
|
4277 | 4358 | PyAffineMapExprList::bind(m);
|
| 4359 | + |
| 4360 | + //---------------------------------------------------------------------------- |
| 4361 | + // Mapping of PyIntegerSet. |
| 4362 | + //---------------------------------------------------------------------------- |
| 4363 | + py::class_<PyIntegerSet>(m, "IntegerSet") |
| 4364 | + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
| 4365 | + &PyIntegerSet::getCapsule) |
| 4366 | + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) |
| 4367 | + .def("__eq__", [](PyIntegerSet &self, |
| 4368 | + PyIntegerSet &other) { return self == other; }) |
| 4369 | + .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) |
| 4370 | + .def("__str__", |
| 4371 | + [](PyIntegerSet &self) { |
| 4372 | + PyPrintAccumulator printAccum; |
| 4373 | + mlirIntegerSetPrint(self, printAccum.getCallback(), |
| 4374 | + printAccum.getUserData()); |
| 4375 | + return printAccum.join(); |
| 4376 | + }) |
| 4377 | + .def("__repr__", |
| 4378 | + [](PyIntegerSet &self) { |
| 4379 | + PyPrintAccumulator printAccum; |
| 4380 | + printAccum.parts.append("IntegerSet("); |
| 4381 | + mlirIntegerSetPrint(self, printAccum.getCallback(), |
| 4382 | + printAccum.getUserData()); |
| 4383 | + printAccum.parts.append(")"); |
| 4384 | + return printAccum.join(); |
| 4385 | + }) |
| 4386 | + .def_property_readonly( |
| 4387 | + "context", |
| 4388 | + [](PyIntegerSet &self) { return self.getContext().getObject(); }) |
| 4389 | + .def( |
| 4390 | + "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, |
| 4391 | + kDumpDocstring) |
| 4392 | + .def_static( |
| 4393 | + "get", |
| 4394 | + [](intptr_t numDims, intptr_t numSymbols, py::list exprs, |
| 4395 | + std::vector<bool> eqFlags, DefaultingPyMlirContext context) { |
| 4396 | + if (exprs.size() != eqFlags.size()) |
| 4397 | + throw py::value_error( |
| 4398 | + "Expected the number of constraints to match " |
| 4399 | + "that of equality flags"); |
| 4400 | + if (exprs.empty()) |
| 4401 | + throw py::value_error("Expected non-empty list of constraints"); |
| 4402 | + |
| 4403 | + // Copy over to a SmallVector because std::vector has a |
| 4404 | + // specialization for booleans that packs data and does not |
| 4405 | + // expose a `bool *`. |
| 4406 | + SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end()); |
| 4407 | + |
| 4408 | + SmallVector<MlirAffineExpr> affineExprs; |
| 4409 | + pyListToVector<PyAffineExpr>(exprs, affineExprs, |
| 4410 | + "attempting to create an IntegerSet"); |
| 4411 | + MlirIntegerSet set = mlirIntegerSetGet( |
| 4412 | + context->get(), numDims, numSymbols, exprs.size(), |
| 4413 | + affineExprs.data(), flags.data()); |
| 4414 | + return PyIntegerSet(context->getRef(), set); |
| 4415 | + }, |
| 4416 | + py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), |
| 4417 | + py::arg("eq_flags"), py::arg("context") = py::none()) |
| 4418 | + .def_static( |
| 4419 | + "get_empty", |
| 4420 | + [](intptr_t numDims, intptr_t numSymbols, |
| 4421 | + DefaultingPyMlirContext context) { |
| 4422 | + MlirIntegerSet set = |
| 4423 | + mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); |
| 4424 | + return PyIntegerSet(context->getRef(), set); |
| 4425 | + }, |
| 4426 | + py::arg("num_dims"), py::arg("num_symbols"), |
| 4427 | + py::arg("context") = py::none()) |
| 4428 | + .def("get_replaced", |
| 4429 | + [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, |
| 4430 | + intptr_t numResultDims, intptr_t numResultSymbols) { |
| 4431 | + if (static_cast<intptr_t>(dimExprs.size()) != |
| 4432 | + mlirIntegerSetGetNumDims(self)) |
| 4433 | + throw py::value_error( |
| 4434 | + "Expected the number of dimension replacement expressions " |
| 4435 | + "to match that of dimensions"); |
| 4436 | + if (static_cast<intptr_t>(symbolExprs.size()) != |
| 4437 | + mlirIntegerSetGetNumSymbols(self)) |
| 4438 | + throw py::value_error( |
| 4439 | + "Expected the number of symbol replacement expressions " |
| 4440 | + "to match that of symbols"); |
| 4441 | + |
| 4442 | + SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs; |
| 4443 | + pyListToVector<PyAffineExpr>( |
| 4444 | + dimExprs, dimAffineExprs, |
| 4445 | + "attempting to create an IntegerSet by replacing dimensions"); |
| 4446 | + pyListToVector<PyAffineExpr>( |
| 4447 | + symbolExprs, symbolAffineExprs, |
| 4448 | + "attempting to create an IntegerSet by replacing symbols"); |
| 4449 | + MlirIntegerSet set = mlirIntegerSetReplaceGet( |
| 4450 | + self, dimAffineExprs.data(), symbolAffineExprs.data(), |
| 4451 | + numResultDims, numResultSymbols); |
| 4452 | + return PyIntegerSet(self.getContext(), set); |
| 4453 | + }) |
| 4454 | + .def_property_readonly("is_canonical_empty", |
| 4455 | + [](PyIntegerSet &self) { |
| 4456 | + return mlirIntegerSetIsCanonicalEmpty(self); |
| 4457 | + }) |
| 4458 | + .def_property_readonly( |
| 4459 | + "n_dims", |
| 4460 | + [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) |
| 4461 | + .def_property_readonly( |
| 4462 | + "n_symbols", |
| 4463 | + [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) |
| 4464 | + .def_property_readonly( |
| 4465 | + "n_inputs", |
| 4466 | + [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) |
| 4467 | + .def_property_readonly("n_equalities", |
| 4468 | + [](PyIntegerSet &self) { |
| 4469 | + return mlirIntegerSetGetNumEqualities(self); |
| 4470 | + }) |
| 4471 | + .def_property_readonly("n_inequalities", |
| 4472 | + [](PyIntegerSet &self) { |
| 4473 | + return mlirIntegerSetGetNumInequalities(self); |
| 4474 | + }) |
| 4475 | + .def_property_readonly("constraints", [](PyIntegerSet &self) { |
| 4476 | + return PyIntegerSetConstraintList(self); |
| 4477 | + }); |
| 4478 | + PyIntegerSetConstraint::bind(m); |
| 4479 | + PyIntegerSetConstraintList::bind(m); |
4278 | 4480 | }
|
0 commit comments