Skip to content

Commit 2065419

Browse files
ftynsememfrob
authored andcommitted
[mlir] Add Python bindings for IntegerSet
This follows up on the introduction of C API for the same object and is similar to AffineExpr and AffineMap. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D95437
1 parent 40f32d3 commit 2065419

File tree

4 files changed

+390
-18
lines changed

4 files changed

+390
-18
lines changed

mlir/include/mlir-c/Bindings/Python/Interop.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@
2626
#include "mlir-c/AffineExpr.h"
2727
#include "mlir-c/AffineMap.h"
2828
#include "mlir-c/IR.h"
29+
#include "mlir-c/IntegerSet.h"
2930
#include "mlir-c/Pass.h"
3031

3132
#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr"
3233
#define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr"
3334
#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
3435
#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
36+
#define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr"
3537
#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
3638
#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
3739
#define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr"
@@ -240,6 +242,25 @@ static inline MlirAffineMap mlirPythonCapsuleToAffineMap(PyObject *capsule) {
240242
return affineMap;
241243
}
242244

245+
/** Creates a capsule object encapsulating the raw C-API MlirIntegerSet.
246+
* The returned capsule does not extend or affect ownership of any Python
247+
* objects that reference the set in any way. */
248+
static inline PyObject *
249+
mlirPythonIntegerSetToCapsule(MlirIntegerSet integerSet) {
250+
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(integerSet),
251+
MLIR_PYTHON_CAPSULE_INTEGER_SET, NULL);
252+
}
253+
254+
/** Extracts an MlirIntegerSet from a capsule as produced from
255+
* mlirPythonIntegerSetToCapsule. If the capsule is not of the right type, then
256+
* a null set is returned (as checked via mlirIntegerSetIsNull). In such a
257+
* case, the Python APIs will have already set an error. */
258+
static inline MlirIntegerSet mlirPythonCapsuleToIntegerSet(PyObject *capsule) {
259+
void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_INTEGER_SET);
260+
MlirIntegerSet integerSet = {ptr};
261+
return integerSet;
262+
}
263+
243264
#ifdef __cplusplus
244265
}
245266
#endif

mlir/lib/Bindings/Python/IRModules.cpp

Lines changed: 220 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir-c/Bindings/Python/Interop.h"
1616
#include "mlir-c/BuiltinAttributes.h"
1717
#include "mlir-c/BuiltinTypes.h"
18+
#include "mlir-c/IntegerSet.h"
1819
#include "mlir-c/Registration.h"
1920
#include "llvm/ADT/SmallVector.h"
2021
#include <pybind11/stl.h>
@@ -3331,6 +3332,102 @@ PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) {
33313332
rawAffineMap);
33323333
}
33333334

3335+
//------------------------------------------------------------------------------
3336+
// PyIntegerSet and utilities.
3337+
//------------------------------------------------------------------------------
3338+
3339+
class PyIntegerSetConstraint {
3340+
public:
3341+
PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {}
3342+
3343+
PyAffineExpr getExpr() {
3344+
return PyAffineExpr(set.getContext(),
3345+
mlirIntegerSetGetConstraint(set, pos));
3346+
}
3347+
3348+
bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); }
3349+
3350+
static void bind(py::module &m) {
3351+
py::class_<PyIntegerSetConstraint>(m, "IntegerSetConstraint")
3352+
.def_property_readonly("expr", &PyIntegerSetConstraint::getExpr)
3353+
.def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq);
3354+
}
3355+
3356+
private:
3357+
PyIntegerSet set;
3358+
intptr_t pos;
3359+
};
3360+
3361+
class PyIntegerSetConstraintList
3362+
: public Sliceable<PyIntegerSetConstraintList, PyIntegerSetConstraint> {
3363+
public:
3364+
static constexpr const char *pyClassName = "IntegerSetConstraintList";
3365+
3366+
PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0,
3367+
intptr_t length = -1, intptr_t step = 1)
3368+
: Sliceable(startIndex,
3369+
length == -1 ? mlirIntegerSetGetNumConstraints(set) : length,
3370+
step),
3371+
set(set) {}
3372+
3373+
intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); }
3374+
3375+
PyIntegerSetConstraint getElement(intptr_t pos) {
3376+
return PyIntegerSetConstraint(set, pos);
3377+
}
3378+
3379+
PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length,
3380+
intptr_t step) {
3381+
return PyIntegerSetConstraintList(set, startIndex, length, step);
3382+
}
3383+
3384+
private:
3385+
PyIntegerSet set;
3386+
};
3387+
3388+
bool PyIntegerSet::operator==(const PyIntegerSet &other) {
3389+
return mlirIntegerSetEqual(integerSet, other.integerSet);
3390+
}
3391+
3392+
py::object PyIntegerSet::getCapsule() {
3393+
return py::reinterpret_steal<py::object>(
3394+
mlirPythonIntegerSetToCapsule(*this));
3395+
}
3396+
3397+
PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) {
3398+
MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr());
3399+
if (mlirIntegerSetIsNull(rawIntegerSet))
3400+
throw py::error_already_set();
3401+
return PyIntegerSet(
3402+
PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)),
3403+
rawIntegerSet);
3404+
}
3405+
3406+
/// Attempts to populate `result` with the content of `list` casted to the
3407+
/// appropriate type (Python and C types are provided as template arguments).
3408+
/// Throws errors in case of failure, using "action" to describe what the caller
3409+
/// was attempting to do.
3410+
template <typename PyType, typename CType>
3411+
static void pyListToVector(py::list list, llvm::SmallVectorImpl<CType> &result,
3412+
StringRef action) {
3413+
result.reserve(py::len(list));
3414+
for (py::handle item : list) {
3415+
try {
3416+
result.push_back(item.cast<PyType>());
3417+
} catch (py::cast_error &err) {
3418+
std::string msg = (llvm::Twine("Invalid expression when ") + action +
3419+
" (" + err.what() + ")")
3420+
.str();
3421+
throw py::cast_error(msg);
3422+
} catch (py::reference_cast_error &err) {
3423+
std::string msg = (llvm::Twine("Invalid expression (None?) when ") +
3424+
action + " (" + err.what() + ")")
3425+
.str();
3426+
throw py::cast_error(msg);
3427+
}
3428+
}
3429+
}
3430+
33343431
//------------------------------------------------------------------------------
33353432
// Populates the pybind11 IR submodule.
33363433
//------------------------------------------------------------------------------
@@ -4152,24 +4249,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
41524249
[](intptr_t dimCount, intptr_t symbolCount, py::list exprs,
41534250
DefaultingPyMlirContext context) {
41544251
SmallVector<MlirAffineExpr> affineExprs;
4155-
affineExprs.reserve(py::len(exprs));
4156-
for (py::handle expr : exprs) {
4157-
try {
4158-
affineExprs.push_back(expr.cast<PyAffineExpr>());
4159-
} catch (py::cast_error &err) {
4160-
std::string msg =
4161-
std::string("Invalid expression when attempting to create "
4162-
"an AffineMap (") +
4163-
err.what() + ")";
4164-
throw py::cast_error(msg);
4165-
} catch (py::reference_cast_error &err) {
4166-
std::string msg =
4167-
std::string("Invalid expression (None?) when attempting to "
4168-
"create an AffineMap (") +
4169-
err.what() + ")";
4170-
throw py::cast_error(msg);
4171-
}
4172-
}
4252+
pyListToVector<PyAffineExpr, MlirAffineExpr>(
4253+
exprs, affineExprs, "attempting to create an AffineMap");
41734254
MlirAffineMap map =
41744255
mlirAffineMapGet(context->get(), dimCount, symbolCount,
41754256
affineExprs.size(), affineExprs.data());
@@ -4275,4 +4356,125 @@ void mlir::python::populateIRSubmodule(py::module &m) {
42754356
return PyAffineMapExprList(self);
42764357
});
42774358
PyAffineMapExprList::bind(m);
4359+
4360+
//----------------------------------------------------------------------------
4361+
// Mapping of PyIntegerSet.
4362+
//----------------------------------------------------------------------------
4363+
py::class_<PyIntegerSet>(m, "IntegerSet")
4364+
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
4365+
&PyIntegerSet::getCapsule)
4366+
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule)
4367+
.def("__eq__", [](PyIntegerSet &self,
4368+
PyIntegerSet &other) { return self == other; })
4369+
.def("__eq__", [](PyIntegerSet &self, py::object other) { return false; })
4370+
.def("__str__",
4371+
[](PyIntegerSet &self) {
4372+
PyPrintAccumulator printAccum;
4373+
mlirIntegerSetPrint(self, printAccum.getCallback(),
4374+
printAccum.getUserData());
4375+
return printAccum.join();
4376+
})
4377+
.def("__repr__",
4378+
[](PyIntegerSet &self) {
4379+
PyPrintAccumulator printAccum;
4380+
printAccum.parts.append("IntegerSet(");
4381+
mlirIntegerSetPrint(self, printAccum.getCallback(),
4382+
printAccum.getUserData());
4383+
printAccum.parts.append(")");
4384+
return printAccum.join();
4385+
})
4386+
.def_property_readonly(
4387+
"context",
4388+
[](PyIntegerSet &self) { return self.getContext().getObject(); })
4389+
.def(
4390+
"dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
4391+
kDumpDocstring)
4392+
.def_static(
4393+
"get",
4394+
[](intptr_t numDims, intptr_t numSymbols, py::list exprs,
4395+
std::vector<bool> eqFlags, DefaultingPyMlirContext context) {
4396+
if (exprs.size() != eqFlags.size())
4397+
throw py::value_error(
4398+
"Expected the number of constraints to match "
4399+
"that of equality flags");
4400+
if (exprs.empty())
4401+
throw py::value_error("Expected non-empty list of constraints");
4402+
4403+
// Copy over to a SmallVector because std::vector has a
4404+
// specialization for booleans that packs data and does not
4405+
// expose a `bool *`.
4406+
SmallVector<bool, 8> flags(eqFlags.begin(), eqFlags.end());
4407+
4408+
SmallVector<MlirAffineExpr> affineExprs;
4409+
pyListToVector<PyAffineExpr>(exprs, affineExprs,
4410+
"attempting to create an IntegerSet");
4411+
MlirIntegerSet set = mlirIntegerSetGet(
4412+
context->get(), numDims, numSymbols, exprs.size(),
4413+
affineExprs.data(), flags.data());
4414+
return PyIntegerSet(context->getRef(), set);
4415+
},
4416+
py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"),
4417+
py::arg("eq_flags"), py::arg("context") = py::none())
4418+
.def_static(
4419+
"get_empty",
4420+
[](intptr_t numDims, intptr_t numSymbols,
4421+
DefaultingPyMlirContext context) {
4422+
MlirIntegerSet set =
4423+
mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols);
4424+
return PyIntegerSet(context->getRef(), set);
4425+
},
4426+
py::arg("num_dims"), py::arg("num_symbols"),
4427+
py::arg("context") = py::none())
4428+
.def("get_replaced",
4429+
[](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs,
4430+
intptr_t numResultDims, intptr_t numResultSymbols) {
4431+
if (static_cast<intptr_t>(dimExprs.size()) !=
4432+
mlirIntegerSetGetNumDims(self))
4433+
throw py::value_error(
4434+
"Expected the number of dimension replacement expressions "
4435+
"to match that of dimensions");
4436+
if (static_cast<intptr_t>(symbolExprs.size()) !=
4437+
mlirIntegerSetGetNumSymbols(self))
4438+
throw py::value_error(
4439+
"Expected the number of symbol replacement expressions "
4440+
"to match that of symbols");
4441+
4442+
SmallVector<MlirAffineExpr> dimAffineExprs, symbolAffineExprs;
4443+
pyListToVector<PyAffineExpr>(
4444+
dimExprs, dimAffineExprs,
4445+
"attempting to create an IntegerSet by replacing dimensions");
4446+
pyListToVector<PyAffineExpr>(
4447+
symbolExprs, symbolAffineExprs,
4448+
"attempting to create an IntegerSet by replacing symbols");
4449+
MlirIntegerSet set = mlirIntegerSetReplaceGet(
4450+
self, dimAffineExprs.data(), symbolAffineExprs.data(),
4451+
numResultDims, numResultSymbols);
4452+
return PyIntegerSet(self.getContext(), set);
4453+
})
4454+
.def_property_readonly("is_canonical_empty",
4455+
[](PyIntegerSet &self) {
4456+
return mlirIntegerSetIsCanonicalEmpty(self);
4457+
})
4458+
.def_property_readonly(
4459+
"n_dims",
4460+
[](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); })
4461+
.def_property_readonly(
4462+
"n_symbols",
4463+
[](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); })
4464+
.def_property_readonly(
4465+
"n_inputs",
4466+
[](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); })
4467+
.def_property_readonly("n_equalities",
4468+
[](PyIntegerSet &self) {
4469+
return mlirIntegerSetGetNumEqualities(self);
4470+
})
4471+
.def_property_readonly("n_inequalities",
4472+
[](PyIntegerSet &self) {
4473+
return mlirIntegerSetGetNumInequalities(self);
4474+
})
4475+
.def_property_readonly("constraints", [](PyIntegerSet &self) {
4476+
return PyIntegerSetConstraintList(self);
4477+
});
4478+
PyIntegerSetConstraint::bind(m);
4479+
PyIntegerSetConstraintList::bind(m);
42784480
}

mlir/lib/Bindings/Python/IRModules.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir-c/AffineExpr.h"
1717
#include "mlir-c/AffineMap.h"
1818
#include "mlir-c/IR.h"
19+
#include "mlir-c/IntegerSet.h"
1920
#include "llvm/ADT/DenseMap.h"
2021

2122
namespace mlir {
@@ -726,6 +727,26 @@ class PyAffineMap : public BaseContextObject {
726727
MlirAffineMap affineMap;
727728
};
728729

730+
class PyIntegerSet : public BaseContextObject {
731+
public:
732+
PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
733+
: BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
734+
bool operator==(const PyIntegerSet &other);
735+
operator MlirIntegerSet() const { return integerSet; }
736+
MlirIntegerSet get() const { return integerSet; }
737+
738+
/// Gets a capsule wrapping the void* within the MlirIntegerSet.
739+
pybind11::object getCapsule();
740+
741+
/// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
742+
/// Note that PyIntegerSet instances may be uniqued, so the returned object
743+
/// may be a pre-existing object. Integer sets are owned by the context.
744+
static PyIntegerSet createFromCapsule(pybind11::object capsule);
745+
746+
private:
747+
MlirIntegerSet integerSet;
748+
};
749+
729750
void populateIRSubmodule(pybind11::module &m);
730751

731752
} // namespace python

0 commit comments

Comments
 (0)