Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit a09c380

Browse files
authored
[MLIR] [python] A few improvements to the Python bindings (#131686)
* `PyRegionList` is now sliceable. The dialect bindings generator seems to assume it is sliceable already (!), yet accessing e.g. `cases` on `scf.IndexedSwitchOp` raises a `TypeError` at runtime. * `PyBlockList` and `PyOperationList` support negative indexing. It is common for containers to do that in Python, and most container in the MLIR Python bindings already allow the index to be negative.
1 parent 3650161 commit a09c380

File tree

1 file changed

+33
-16
lines changed

1 file changed

+33
-16
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -361,37 +361,45 @@ class PyRegionIterator {
361361

362362
/// Regions of an op are fixed length and indexed numerically so are represented
363363
/// with a sequence-like container.
364-
class PyRegionList {
364+
class PyRegionList : public Sliceable<PyRegionList, PyRegion> {
365365
public:
366-
PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
366+
static constexpr const char *pyClassName = "RegionSequence";
367+
368+
PyRegionList(PyOperationRef operation, intptr_t startIndex = 0,
369+
intptr_t length = -1, intptr_t step = 1)
370+
: Sliceable(startIndex,
371+
length == -1 ? mlirOperationGetNumRegions(operation->get())
372+
: length,
373+
step),
374+
operation(std::move(operation)) {}
367375

368376
PyRegionIterator dunderIter() {
369377
operation->checkValid();
370378
return PyRegionIterator(operation);
371379
}
372380

373-
intptr_t dunderLen() {
381+
static void bindDerived(ClassTy &c) {
382+
c.def("__iter__", &PyRegionList::dunderIter);
383+
}
384+
385+
private:
386+
/// Give the parent CRTP class access to hook implementations below.
387+
friend class Sliceable<PyRegionList, PyRegion>;
388+
389+
intptr_t getRawNumElements() {
374390
operation->checkValid();
375391
return mlirOperationGetNumRegions(operation->get());
376392
}
377393

378-
PyRegion dunderGetItem(intptr_t index) {
379-
// dunderLen checks validity.
380-
if (index < 0 || index >= dunderLen()) {
381-
throw nb::index_error("attempt to access out of bounds region");
382-
}
383-
MlirRegion region = mlirOperationGetRegion(operation->get(), index);
384-
return PyRegion(operation, region);
394+
PyRegion getRawElement(intptr_t pos) {
395+
operation->checkValid();
396+
return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos));
385397
}
386398

387-
static void bind(nb::module_ &m) {
388-
nb::class_<PyRegionList>(m, "RegionSequence")
389-
.def("__len__", &PyRegionList::dunderLen)
390-
.def("__iter__", &PyRegionList::dunderIter)
391-
.def("__getitem__", &PyRegionList::dunderGetItem);
399+
PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
400+
return PyRegionList(operation, startIndex, length, step);
392401
}
393402

394-
private:
395403
PyOperationRef operation;
396404
};
397405

@@ -450,6 +458,9 @@ class PyBlockList {
450458

451459
PyBlock dunderGetItem(intptr_t index) {
452460
operation->checkValid();
461+
if (index < 0) {
462+
index += dunderLen();
463+
}
453464
if (index < 0) {
454465
throw nb::index_error("attempt to access out of bounds block");
455466
}
@@ -546,6 +557,9 @@ class PyOperationList {
546557

547558
nb::object dunderGetItem(intptr_t index) {
548559
parentOperation->checkValid();
560+
if (index < 0) {
561+
index += dunderLen();
562+
}
549563
if (index < 0) {
550564
throw nb::index_error("attempt to access out of bounds operation");
551565
}
@@ -2629,6 +2643,9 @@ class PyOpAttributeMap {
26292643
}
26302644

26312645
PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2646+
if (index < 0) {
2647+
index += dunderLen();
2648+
}
26322649
if (index < 0 || index >= dunderLen()) {
26332650
throw nb::index_error("attempt to access out of bounds attribute");
26342651
}

0 commit comments

Comments
 (0)