From 7b233f77d7a0786aec27db1a65cecd85ee1e9955 Mon Sep 17 00:00:00 2001 From: makslevental Date: Wed, 3 Dec 2025 12:22:06 -0800 Subject: [PATCH] [MLIR][Python] make Sliceable inherit from Sequence --- mlir/lib/Bindings/Python/NanobindUtils.h | 12 +++++++++++- mlir/test/python/ir/operation.py | 4 ++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h index 658e8ad5330ef..aea195fecae82 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -19,6 +19,7 @@ #include "llvm/Support/raw_ostream.h" #include +#include #include template <> @@ -344,7 +345,16 @@ class Sliceable { /// Binds the indexing and length methods in the Python class. static void bind(nanobind::module_ &m) { - auto clazz = nanobind::class_(m, Derived::pyClassName) + const std::type_info &elemTy = typeid(ElementTy); + PyObject *elemTyInfo = nanobind::detail::nb_type_lookup(&elemTy); + assert(elemTyInfo && + "expected nb_type_lookup to succeed for Sliceable elemTy"); + nanobind::handle elemTyName = nanobind::detail::nb_type_name(elemTyInfo); + std::string sig = std::string("class ") + Derived::pyClassName + + "(collections.abc.Sequence[" + + nanobind::cast(elemTyName) + "])"; + auto clazz = nanobind::class_(m, Derived::pyClassName, + nanobind::sig(sig.c_str())) .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index ca99c2a985242..d124c284197b8 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -43,6 +43,10 @@ def testTraverseOpRegionBlockIterators(): ) op = module.operation assert op.context is ctx + # Note, __nb_signature__ stores the fully-qualified signature - the actual type stub emitted is + # class RegionSequence(Sequence[Region]) + # CHECK: class RegionSequence(collections.abc.Sequence[mlir._mlir_libs._mlir.ir.Region]) + print(RegionSequence.__nb_signature__) # Get the block using iterators off of the named collections. regions = list(op.regions[:]) blocks = list(regions[0].blocks)