-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR] [python] A few improvements to the Python bindings #131686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir Author: Sergei Lebedev (superbobry) Changes
Full diff: https://github.com/llvm/llvm-project/pull/131686.diff 3 Files Affected:
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..dc41aaea3261c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -361,37 +361,45 @@ class PyRegionIterator {
/// Regions of an op are fixed length and indexed numerically so are represented
/// with a sequence-like container.
-class PyRegionList {
+class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
public:
- PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
+ static constexpr const char *pyClassName = "RegionSequence";
+
+ PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
+ intptr_t length = -1, intptr_t step = 1)
+ : Sliceable(startIndex,
+ length == -1 ? mlirOperationGetNumRegions(operation->get())
+ : length,
+ step),
+ operation(std::move(operation)) {}
PyRegionIterator dunderIter() {
operation->checkValid();
return PyRegionIterator(operation);
}
- intptr_t dunderLen() {
+ static void bindDerived(ClassTy &c) {
+ c.def("__iter__", &PyRegionList::dunderIter);
+ }
+
+private:
+ /// Give the parent CRTP class access to hook implementations below.
+ friend class Sliceable<PyRegionList, PyRegion>;
+
+ intptr_t getRawNumElements() {
operation->checkValid();
return mlirOperationGetNumRegions(operation->get());
}
- PyRegion dunderGetItem(intptr_t index) {
- // dunderLen checks validity.
- if (index < 0 || index >= dunderLen()) {
- throw nb::index_error("attempt to access out of bounds region");
- }
- MlirRegion region = mlirOperationGetRegion(operation->get(), index);
- return PyRegion(operation, region);
+ PyRegion getRawElement(intptr_t pos) {
+ operation->checkValid();
+ return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
}
- static void bind(nb::module_ &m) {
- nb::class_<PyRegionList>(m, "RegionSequence")
- .def("__len__", &PyRegionList::dunderLen)
- .def("__iter__", &PyRegionList::dunderIter)
- .def("__getitem__", &PyRegionList::dunderGetItem);
+ PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+ return PyRegionList(operation, startIndex, length, step);
}
-private:
PyOperationRef operation;
};
@@ -450,6 +458,9 @@ class PyBlockList {
PyBlock dunderGetItem(intptr_t index) {
operation->checkValid();
+ if (index < 0) {
+ index += dunderLen();
+ }
if (index < 0) {
throw nb::index_error("attempt to access out of bounds block");
}
@@ -546,6 +557,9 @@ class PyOperationList {
nb::object dunderGetItem(intptr_t index) {
parentOperation->checkValid();
+ if (index < 0) {
+ index += dunderLen();
+ }
if (index < 0) {
throw nb::index_error("attempt to access out of bounds operation");
}
@@ -2629,6 +2643,9 @@ class PyOpAttributeMap {
}
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+ if (index < 0) {
+ index += dunderLen();
+ }
if (index < 0 || index >= dunderLen()) {
throw nb::index_error("attempt to access out of bounds attribute");
}
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index c93de2fe3154e..c60ff72ff9fd4 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -2466,7 +2466,10 @@ class RegionIterator:
def __next__(self) -> Region: ...
class RegionSequence:
+ @overload
def __getitem__(self, arg0: int) -> Region: ...
+ @overload
+ def __getitem__(self, arg0: slice) -> Sequence[Region]: ...
def __iter__(self) -> RegionIterator: ...
def __len__(self) -> int: ...
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index dd2731ba2e1f1..8040ef4a01703 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -44,7 +44,7 @@ def testTraverseOpRegionBlockIterators():
op = module.operation
assert op.context is ctx
# Get the block using iterators off of the named collections.
- regions = list(op.regions)
+ regions = list(op.regions[:])
blocks = list(regions[0].blocks)
# CHECK: MODULE REGIONS=1 BLOCKS=1
print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
@@ -86,8 +86,8 @@ def walk_operations(indent, op):
# CHECK: Block iter: <mlir.{{.+}}.BlockIterator
# CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
print(" Region iter:", iter(op.regions))
- print(" Block iter:", iter(op.regions[0]))
- print("Operation iter:", iter(op.regions[0].blocks[0]))
+ print(" Block iter:", iter(op.regions[-1]))
+ print("Operation iter:", iter(op.regions[-1].blocks[-1]))
# Verify index based traversal of the op/region/block hierarchy.
|
| if (index < 0) { | ||
| throw nb::index_error("attempt to access out of bounds block"); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DCE?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It can still be negative, right?
xs = [1, 2, 3]
xs[-42]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh whoops duh. Ok ignore me then!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
xs = [1, 2, 3]
xs[-42]
Do we have this already in the test suite?
Out-of-scope for your PR here, but could these exceptions be a bit more expressive by including the original index, the adjusted one, and the bounds in the error message?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test suite doesn't check OOB errors at the moment, I think.
makslevental
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems reasonable to me (modulo nits). Thanks!
|
Can we |
Dealer's choice; you can add |
* `PyRegionList` is now sliceable. The dialect bindings generator seems to assume it is sliceable already (!), yet accessing e.g. `cases` on `scf.IndexedSwitchOp` raises a `TypeError` at runtime. * `PyBlockList` and `PyOperationList` support negative indexing. It is common for containers to do that in Python, and most container in the MLIR Python bindings already allow the index to be negative.
fda98f4 to
c274951
Compare
|
Okay, all done. Ready to merge. |
PyRegionListis now sliceable. The dialect bindings generator seems to assume it is sliceable already (!), yet accessing e.g.casesonscf.IndexedSwitchOpraises aTypeErrorat runtime.PyBlockListandPyOperationListsupport negative indexing. It is common for containers to do that in Python, and most container in the MLIR Python bindings already allow the index to be negative.