Skip to content

Conversation

@superbobry
Copy link
Contributor

  • PyRegionList is now sliceable. The dialect bindings generator seems to assume it is sliceable already (!), yet accessing e.g. cases on scf.IndexedSwitchOp raises a TypeError at runtime.
  • PyBlockList and PyOperationList support negative indexing. It is common for containers to do that in Python, and most container in the MLIR Python bindings already allow the index to be negative.

@llvmbot
Copy link
Member

llvmbot commented Mar 17, 2025

@llvm/pr-subscribers-mlir

Author: Sergei Lebedev (superbobry)

Changes
  • PyRegionList is now sliceable. The dialect bindings generator seems to assume it is sliceable already (!), yet accessing e.g. cases on scf.IndexedSwitchOp raises a TypeError at runtime.
  • PyBlockList and PyOperationList support negative indexing. It is common for containers to do that in Python, and most container in the MLIR Python bindings already allow the index to be negative.

Full diff: https://github.com/llvm/llvm-project/pull/131686.diff

3 Files Affected:

  • (modified) mlir/lib/Bindings/Python/IRCore.cpp (+33-16)
  • (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+3)
  • (modified) mlir/test/python/ir/operation.py (+3-3)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..dc41aaea3261c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -361,37 +361,45 @@ class PyRegionIterator {
 
 /// Regions of an op are fixed length and indexed numerically so are represented
 /// with a sequence-like container.
-class PyRegionList {
+class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
 public:
-  PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
+  static constexpr const char *pyClassName = "RegionSequence";
+
+  PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
+               intptr_t length = -1, intptr_t step = 1)
+      : Sliceable(startIndex,
+                  length == -1 ? mlirOperationGetNumRegions(operation->get())
+                               : length,
+                  step),
+        operation(std::move(operation)) {}
 
   PyRegionIterator dunderIter() {
     operation->checkValid();
     return PyRegionIterator(operation);
   }
 
-  intptr_t dunderLen() {
+  static void bindDerived(ClassTy &c) {
+    c.def("__iter__", &PyRegionList::dunderIter);
+  }
+
+private:
+  /// Give the parent CRTP class access to hook implementations below.
+  friend class Sliceable<PyRegionList, PyRegion>;
+
+  intptr_t getRawNumElements() {
     operation->checkValid();
     return mlirOperationGetNumRegions(operation->get());
   }
 
-  PyRegion dunderGetItem(intptr_t index) {
-    // dunderLen checks validity.
-    if (index < 0 || index >= dunderLen()) {
-      throw nb::index_error("attempt to access out of bounds region");
-    }
-    MlirRegion region = mlirOperationGetRegion(operation->get(), index);
-    return PyRegion(operation, region);
+  PyRegion getRawElement(intptr_t pos) {
+    operation->checkValid();
+    return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
   }
 
-  static void bind(nb::module_ &m) {
-    nb::class_<PyRegionList>(m, "RegionSequence")
-        .def("__len__", &PyRegionList::dunderLen)
-        .def("__iter__", &PyRegionList::dunderIter)
-        .def("__getitem__", &PyRegionList::dunderGetItem);
+  PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
+    return PyRegionList(operation, startIndex, length, step);
   }
 
-private:
   PyOperationRef operation;
 };
 
@@ -450,6 +458,9 @@ class PyBlockList {
 
   PyBlock dunderGetItem(intptr_t index) {
     operation->checkValid();
+    if (index < 0) {
+      index += dunderLen();
+    }
     if (index < 0) {
       throw nb::index_error("attempt to access out of bounds block");
     }
@@ -546,6 +557,9 @@ class PyOperationList {
 
   nb::object dunderGetItem(intptr_t index) {
     parentOperation->checkValid();
+    if (index < 0) {
+      index += dunderLen();
+    }
     if (index < 0) {
       throw nb::index_error("attempt to access out of bounds operation");
     }
@@ -2629,6 +2643,9 @@ class PyOpAttributeMap {
   }
 
   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
+    if (index < 0) {
+      index += dunderLen();
+    }
     if (index < 0 || index >= dunderLen()) {
       throw nb::index_error("attempt to access out of bounds attribute");
     }
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index c93de2fe3154e..c60ff72ff9fd4 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -2466,7 +2466,10 @@ class RegionIterator:
     def __next__(self) -> Region: ...
 
 class RegionSequence:
+    @overload
     def __getitem__(self, arg0: int) -> Region: ...
+    @overload
+    def __getitem__(self, arg0: slice) -> Sequence[Region]: ...
     def __iter__(self) -> RegionIterator: ...
     def __len__(self) -> int: ...
 
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index dd2731ba2e1f1..8040ef4a01703 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -44,7 +44,7 @@ def testTraverseOpRegionBlockIterators():
     op = module.operation
     assert op.context is ctx
     # Get the block using iterators off of the named collections.
-    regions = list(op.regions)
+    regions = list(op.regions[:])
     blocks = list(regions[0].blocks)
     # CHECK: MODULE REGIONS=1 BLOCKS=1
     print(f"MODULE REGIONS={len(regions)} BLOCKS={len(blocks)}")
@@ -86,8 +86,8 @@ def walk_operations(indent, op):
     # CHECK:     Block iter: <mlir.{{.+}}.BlockIterator
     # CHECK: Operation iter: <mlir.{{.+}}.OperationIterator
     print("   Region iter:", iter(op.regions))
-    print("    Block iter:", iter(op.regions[0]))
-    print("Operation iter:", iter(op.regions[0].blocks[0]))
+    print("    Block iter:", iter(op.regions[-1]))
+    print("Operation iter:", iter(op.regions[-1].blocks[-1]))
 
 
 # Verify index based traversal of the op/region/block hierarchy.

Comment on lines 464 to 466
if (index < 0) {
throw nb::index_error("attempt to access out of bounds block");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DCE?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can still be negative, right?

xs = [1, 2, 3]
xs[-42]

Copy link
Contributor

@makslevental makslevental Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh whoops duh. Ok ignore me then!

Copy link
Collaborator

@joker-eph joker-eph Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xs = [1, 2, 3]
xs[-42]

Do we have this already in the test suite?

Out-of-scope for your PR here, but could these exceptions be a bit more expressive by including the original index, the adjusted one, and the bounds in the error message?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test suite doesn't check OOB errors at the moment, I think.

Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to me (modulo nits). Thanks!

@joker-eph joker-eph changed the title [MLIR] [PYTHON] A few improvements to the Python bindings [MLIR] [python] A few improvements to the Python bindings Mar 18, 2025
@superbobry
Copy link
Contributor Author

Can we :shipit: or should I add more test cases?

@makslevental
Copy link
Contributor

Can we :shipit: or should I add more test cases?

Dealer's choice; you can add x[-42] or someone can do it in a follow-up. Up to you.

* `PyRegionList` is now sliceable. The dialect bindings generator seems to
  assume it is sliceable already (!), yet accessing e.g. `cases` on
  `scf.IndexedSwitchOp` raises a `TypeError` at runtime.
* `PyBlockList` and `PyOperationList` support negative indexing. It is common
  for containers to do that in Python, and most container in the MLIR Python
  bindings already allow the index to be negative.
@superbobry superbobry force-pushed the piper_export_cl_737773506 branch from fda98f4 to c274951 Compare March 20, 2025 11:49
@superbobry
Copy link
Contributor Author

Okay, all done. Ready to merge.

@makslevental makslevental merged commit c8a9a41 into llvm:main Mar 21, 2025
11 checks passed
@superbobry superbobry deleted the piper_export_cl_737773506 branch March 21, 2025 12:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:python MLIR Python bindings mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants