Skip to content

Commit e6d1dea

Browse files
authored
[MLIR][Python] make Sliceable inherit from Sequence (#170551)
Generates type stubs like ```python class RegionSequence(Sequence[Region]): def __add__(self, arg: RegionSequence, /) -> list[Region]: ... def __iter__(self) -> RegionIterator: """Returns an iterator over the regions in the sequence.""" ```
1 parent 0ecac6d commit e6d1dea

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

mlir/lib/Bindings/Python/NanobindUtils.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "llvm/Support/raw_ostream.h"
2020

2121
#include <string>
22+
#include <typeinfo>
2223
#include <variant>
2324

2425
template <>
@@ -344,7 +345,16 @@ class Sliceable {
344345

345346
/// Binds the indexing and length methods in the Python class.
346347
static void bind(nanobind::module_ &m) {
347-
auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName)
348+
const std::type_info &elemTy = typeid(ElementTy);
349+
PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy);
350+
assert(elemTyInfo &&
351+
"expected nb_type_lookup to succeed for Sliceable elemTy");
352+
nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo);
353+
std::string sig = std::string("class ") + Derived::pyClassName +
354+
"(collections.abc.Sequence[" +
355+
nanobind::cast<std::string>(elemTyName) + "])";
356+
auto clazz = nanobind::class_<Derived>(m, Derived::pyClassName,
357+
nanobind::sig(sig.c_str()))
348358
.def("__add__", &Sliceable::dunderAdd);
349359
Derived::bindDerived(clazz);
350360

mlir/test/python/ir/operation.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def testTraverseOpRegionBlockIterators():
4343
)
4444
op = module.operation
4545
assert op.context is ctx
46+
# Note, __nb_signature__ stores the fully-qualified signature - the actual type stub emitted is
47+
# class RegionSequence(Sequence[Region])
48+
# CHECK: class RegionSequence(collections.abc.Sequence[mlir._mlir_libs._mlir.ir.Region])
49+
print(RegionSequence.__nb_signature__)
4650
# Get the block using iterators off of the named collections.
4751
regions = list(op.regions[:])
4852
blocks = list(regions[0].blocks)

0 commit comments

Comments
 (0)