Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit 781d025

Browse files
authored
[MLIR][Python] reland (narrower) type stub generation (#157930)
This a reland of llvm/llvm-project#155741 which was reverted at llvm/llvm-project#157831. This version is narrower in scope - it only turns on automatic stub generation for `MLIRPythonExtension.Core._mlir` and **does not do anything automatically**. Specifically, the only CMake code added to `AddMLIRPython.cmake` is the `mlir_generate_type_stubs` function which is then used only in a manual way. The API for `mlir_generate_type_stubs` is: ``` Arguments: MODULE_NAME: The fully-qualified name of the extension module (used for importing in python). DEPENDS_TARGETS: List of targets these type stubs depend on being built; usually corresponding to the specific extension module (e.g., something like StandalonePythonModules.extension._standaloneDialectsNanobind.dso) and the core bindings extension module (e.g., something like StandalonePythonModules.extension._mlir.dso). OUTPUT_DIR: The root output directory to emit the type stubs into. OUTPUTS: List of expected outputs. DEPENDS_TARGET_SRC_DEPS: List of cpp sources for extension library (for generating a DEPFILE). IMPORT_PATHS: List of paths to add to PYTHONPATH for stubgen. PATTERN_FILE: (Optional) Pattern file (see https://nanobind.readthedocs.io/en/latest/typing.html#pattern-files). Outputs: NB_STUBGEN_CUSTOM_TARGET: The target corresponding to generation which other targets can depend on. ``` Downstream users should use `mlir_generate_type_stubs` in coordination with `declare_mlir_python_sources` to turn on stub generation for their own downstream dialect extensions and upstream dialect extensions if they so choose. Standalone example shows an example. Note, downstream will also need to set `-DMLIR_PYTHON_PACKAGE_PREFIX=...` correctly for their bindings.
1 parent e9d2818 commit 781d025

File tree

18 files changed

+541
-3171
lines changed

18 files changed

+541
-3171
lines changed

mlir/include/mlir/Bindings/Python/NanobindAdaptors.h

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -512,8 +512,13 @@ class mlir_attribute_subclass : public pure_subclass {
512512
.attr("replace")(superCls.attr("__name__"), captureTypeName);
513513
});
514514
if (getTypeIDFunction) {
515-
def_staticmethod("get_static_typeid",
516-
[getTypeIDFunction]() { return getTypeIDFunction(); });
515+
def_staticmethod(
516+
"get_static_typeid",
517+
[getTypeIDFunction]() { return getTypeIDFunction(); },
518+
// clang-format off
519+
nanobind::sig("def get_static_typeid() -> " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID"))
520+
// clang-format on
521+
);
517522
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
518523
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
519524
getTypeIDFunction())(nanobind::cpp_function(
@@ -582,8 +587,9 @@ class mlir_type_subclass : public pure_subclass {
582587

583588
// 'isinstance' method.
584589
static const char kIsinstanceSig[] =
585-
"def isinstance(other_type: " MAKE_MLIR_PYTHON_QUALNAME(
586-
"ir") ".Type) -> bool";
590+
// clang-format off
591+
"def isinstance(other_type: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ") -> bool";
592+
// clang-format on
587593
def_staticmethod(
588594
"isinstance",
589595
[isaFunction](MlirType other) { return isaFunction(other); },
@@ -599,8 +605,13 @@ class mlir_type_subclass : public pure_subclass {
599605
// `def_property_readonly_static` is not available in `pure_subclass` and
600606
// we do not want to introduce the complexity that pybind uses to
601607
// implement it.
602-
def_staticmethod("get_static_typeid",
603-
[getTypeIDFunction]() { return getTypeIDFunction(); });
608+
def_staticmethod(
609+
"get_static_typeid",
610+
[getTypeIDFunction]() { return getTypeIDFunction(); },
611+
// clang-format off
612+
nanobind::sig("def get_static_typeid() -> " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID"))
613+
// clang-format on
614+
);
604615
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
605616
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
606617
getTypeIDFunction())(nanobind::cpp_function(
@@ -665,8 +676,9 @@ class mlir_value_subclass : public pure_subclass {
665676

666677
// 'isinstance' method.
667678
static const char kIsinstanceSig[] =
668-
"def isinstance(other_value: " MAKE_MLIR_PYTHON_QUALNAME(
669-
"ir") ".Value) -> bool";
679+
// clang-format off
680+
"def isinstance(other_value: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ") -> bool";
681+
// clang-format on
670682
def_staticmethod(
671683
"isinstance",
672684
[isaFunction](MlirValue other) { return isaFunction(other); },

mlir/lib/Bindings/Python/DialectPDL.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ static void populateDialectPDLSubmodule(const nanobind::module_ &m) {
6868
rangeType.def_property_readonly(
6969
"element_type",
7070
[](MlirType type) { return mlirPDLRangeTypeGetElementType(type); },
71+
nb::sig(
72+
"def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")),
7173
"Get the element type.");
7274

7375
//===-------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 97 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ struct nb_buffer_info {
167167
};
168168

169169
class nb_buffer : public nb::object {
170-
NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer);
170+
NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer);
171171

172172
nb_buffer_info request() const {
173173
int flags = PyBUF_STRIDES | PyBUF_FORMAT;
@@ -252,8 +252,13 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
252252
return PyAffineMapAttribute(affineMap.getContext(), attr);
253253
},
254254
nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap.");
255-
c.def_prop_ro("value", mlirAffineMapAttrGetValue,
256-
"Returns the value of the AffineMap attribute");
255+
c.def_prop_ro(
256+
"value",
257+
[](PyAffineMapAttribute &self) {
258+
return PyAffineMap(self.getContext(),
259+
mlirAffineMapAttrGetValue(self));
260+
},
261+
"Returns the value of the AffineMap attribute");
257262
}
258263
};
259264

@@ -480,11 +485,13 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
480485

481486
PyArrayAttributeIterator &dunderIter() { return *this; }
482487

483-
MlirAttribute dunderNext() {
488+
nb::typed<nb::object, PyAttribute> dunderNext() {
484489
// TODO: Throw is an inefficient way to stop iteration.
485490
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
486491
throw nb::stop_iteration();
487-
return mlirArrayAttrGetElement(attr.get(), nextIndex++);
492+
return PyAttribute(this->attr.getContext(),
493+
mlirArrayAttrGetElement(attr.get(), nextIndex++))
494+
.maybeDownCast();
488495
}
489496

490497
static void bind(nb::module_ &m) {
@@ -517,12 +524,13 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
517524
},
518525
nb::arg("attributes"), nb::arg("context") = nb::none(),
519526
"Gets a uniqued Array attribute");
520-
c.def("__getitem__",
521-
[](PyArrayAttribute &arr, intptr_t i) {
522-
if (i >= mlirArrayAttrGetNumElements(arr))
523-
throw nb::index_error("ArrayAttribute index out of range");
524-
return arr.getItem(i);
525-
})
527+
c.def(
528+
"__getitem__",
529+
[](PyArrayAttribute &arr, intptr_t i) {
530+
if (i >= mlirArrayAttrGetNumElements(arr))
531+
throw nb::index_error("ArrayAttribute index out of range");
532+
return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
533+
})
526534
.def("__len__",
527535
[](const PyArrayAttribute &arr) {
528536
return mlirArrayAttrGetNumElements(arr);
@@ -611,10 +619,12 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
611619
"Returns the value of the integer attribute");
612620
c.def("__int__", toPyInt,
613621
"Converts the value of the integer attribute to a Python int");
614-
c.def_prop_ro_static("static_typeid",
615-
[](nb::object & /*class*/) -> MlirTypeID {
616-
return mlirIntegerAttrGetTypeID();
617-
});
622+
c.def_prop_ro_static(
623+
"static_typeid",
624+
[](nb::object & /*class*/) {
625+
return PyTypeID(mlirIntegerAttrGetTypeID());
626+
},
627+
nanobind::sig("def static_typeid(/) -> TypeID"));
618628
}
619629

620630
private:
@@ -657,8 +667,8 @@ class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
657667
static constexpr const char *pyClassName = "SymbolRefAttr";
658668
using PyConcreteAttribute::PyConcreteAttribute;
659669

660-
static MlirAttribute fromList(const std::vector<std::string> &symbols,
661-
PyMlirContext &context) {
670+
static PySymbolRefAttribute fromList(const std::vector<std::string> &symbols,
671+
PyMlirContext &context) {
662672
if (symbols.empty())
663673
throw std::runtime_error("SymbolRefAttr must be composed of at least "
664674
"one symbol.");
@@ -668,8 +678,10 @@ class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
668678
referenceAttrs.push_back(
669679
mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i])));
670680
}
671-
return mlirSymbolRefAttrGet(context.get(), rootSymbol,
672-
referenceAttrs.size(), referenceAttrs.data());
681+
return PySymbolRefAttribute(context.getRef(),
682+
mlirSymbolRefAttrGet(context.get(), rootSymbol,
683+
referenceAttrs.size(),
684+
referenceAttrs.data()));
673685
}
674686

675687
static void bindDerived(ClassTy &c) {
@@ -746,7 +758,11 @@ class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
746758
return PyOpaqueAttribute(context->getRef(), attr);
747759
},
748760
nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"),
749-
nb::arg("context") = nb::none(), "Gets an Opaque attribute.");
761+
nb::arg("context") = nb::none(),
762+
// clang-format off
763+
nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"),
764+
// clang-format on
765+
"Gets an Opaque attribute.");
750766
c.def_prop_ro(
751767
"dialect_namespace",
752768
[](PyOpaqueAttribute &self) {
@@ -764,59 +780,6 @@ class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
764780
}
765781
};
766782

767-
class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
768-
public:
769-
static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
770-
static constexpr const char *pyClassName = "StringAttr";
771-
using PyConcreteAttribute::PyConcreteAttribute;
772-
static constexpr GetTypeIDFunctionTy getTypeIdFunction =
773-
mlirStringAttrGetTypeID;
774-
775-
static void bindDerived(ClassTy &c) {
776-
c.def_static(
777-
"get",
778-
[](const std::string &value, DefaultingPyMlirContext context) {
779-
MlirAttribute attr =
780-
mlirStringAttrGet(context->get(), toMlirStringRef(value));
781-
return PyStringAttribute(context->getRef(), attr);
782-
},
783-
nb::arg("value"), nb::arg("context") = nb::none(),
784-
"Gets a uniqued string attribute");
785-
c.def_static(
786-
"get",
787-
[](const nb::bytes &value, DefaultingPyMlirContext context) {
788-
MlirAttribute attr =
789-
mlirStringAttrGet(context->get(), toMlirStringRef(value));
790-
return PyStringAttribute(context->getRef(), attr);
791-
},
792-
nb::arg("value"), nb::arg("context") = nb::none(),
793-
"Gets a uniqued string attribute");
794-
c.def_static(
795-
"get_typed",
796-
[](PyType &type, const std::string &value) {
797-
MlirAttribute attr =
798-
mlirStringAttrTypedGet(type, toMlirStringRef(value));
799-
return PyStringAttribute(type.getContext(), attr);
800-
},
801-
nb::arg("type"), nb::arg("value"),
802-
"Gets a uniqued string attribute associated to a type");
803-
c.def_prop_ro(
804-
"value",
805-
[](PyStringAttribute &self) {
806-
MlirStringRef stringRef = mlirStringAttrGetValue(self);
807-
return nb::str(stringRef.data, stringRef.length);
808-
},
809-
"Returns the value of the string attribute");
810-
c.def_prop_ro(
811-
"value_bytes",
812-
[](PyStringAttribute &self) {
813-
MlirStringRef stringRef = mlirStringAttrGetValue(self);
814-
return nb::bytes(stringRef.data, stringRef.length);
815-
},
816-
"Returns the value of the string attribute as `bytes`");
817-
}
818-
};
819-
820783
// TODO: Support construction of string elements.
821784
class PyDenseElementsAttribute
822785
: public PyConcreteAttribute<PyDenseElementsAttribute> {
@@ -1028,11 +991,14 @@ class PyDenseElementsAttribute
1028991
PyDenseElementsAttribute::bf_releasebuffer;
1029992
#endif
1030993
c.def("__len__", &PyDenseElementsAttribute::dunderLen)
1031-
.def_static("get", PyDenseElementsAttribute::getFromBuffer,
1032-
nb::arg("array"), nb::arg("signless") = true,
1033-
nb::arg("type") = nb::none(), nb::arg("shape") = nb::none(),
1034-
nb::arg("context") = nb::none(),
1035-
kDenseElementsAttrGetDocstring)
994+
.def_static(
995+
"get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"),
996+
nb::arg("signless") = true, nb::arg("type") = nb::none(),
997+
nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(),
998+
// clang-format off
999+
nb::sig("def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr"),
1000+
// clang-format on
1001+
kDenseElementsAttrGetDocstring)
10361002
.def_static("get", PyDenseElementsAttribute::getFromList,
10371003
nb::arg("attrs"), nb::arg("type") = nb::none(),
10381004
nb::arg("context") = nb::none(),
@@ -1048,7 +1014,9 @@ class PyDenseElementsAttribute
10481014
if (!mlirDenseElementsAttrIsSplat(self))
10491015
throw nb::value_error(
10501016
"get_splat_value called on a non-splat attribute");
1051-
return mlirDenseElementsAttrGetSplatValue(self);
1017+
return PyAttribute(self.getContext(),
1018+
mlirDenseElementsAttrGetSplatValue(self))
1019+
.maybeDownCast();
10521020
});
10531021
}
10541022

@@ -1509,6 +1477,9 @@ class PyDenseResourceElementsAttribute
15091477
nb::arg("array"), nb::arg("name"), nb::arg("type"),
15101478
nb::arg("alignment") = nb::none(),
15111479
nb::arg("is_mutable") = false, nb::arg("context") = nb::none(),
1480+
// clang-format off
1481+
nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"),
1482+
// clang-format on
15121483
kDenseResourceElementsAttrGetFromBufferDocstring);
15131484
}
15141485
};
@@ -1556,7 +1527,7 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
15561527
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
15571528
if (mlirAttributeIsNull(attr))
15581529
throw nb::key_error("attempt to access a non-existent attribute");
1559-
return attr;
1530+
return PyAttribute(self.getContext(), attr).maybeDownCast();
15601531
});
15611532
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
15621533
if (index < 0 || index >= self.dunderLen()) {
@@ -1624,7 +1595,8 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
16241595
nb::arg("value"), nb::arg("context") = nb::none(),
16251596
"Gets a uniqued Type attribute");
16261597
c.def_prop_ro("value", [](PyTypeAttribute &self) {
1627-
return mlirTypeAttrGetValue(self.get());
1598+
return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
1599+
.maybeDownCast();
16281600
});
16291601
}
16301602
};
@@ -1761,6 +1733,50 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
17611733

17621734
} // namespace
17631735

1736+
void PyStringAttribute::bindDerived(ClassTy &c) {
1737+
c.def_static(
1738+
"get",
1739+
[](const std::string &value, DefaultingPyMlirContext context) {
1740+
MlirAttribute attr =
1741+
mlirStringAttrGet(context->get(), toMlirStringRef(value));
1742+
return PyStringAttribute(context->getRef(), attr);
1743+
},
1744+
nb::arg("value"), nb::arg("context") = nb::none(),
1745+
"Gets a uniqued string attribute");
1746+
c.def_static(
1747+
"get",
1748+
[](const nb::bytes &value, DefaultingPyMlirContext context) {
1749+
MlirAttribute attr =
1750+
mlirStringAttrGet(context->get(), toMlirStringRef(value));
1751+
return PyStringAttribute(context->getRef(), attr);
1752+
},
1753+
nb::arg("value"), nb::arg("context") = nb::none(),
1754+
"Gets a uniqued string attribute");
1755+
c.def_static(
1756+
"get_typed",
1757+
[](PyType &type, const std::string &value) {
1758+
MlirAttribute attr =
1759+
mlirStringAttrTypedGet(type, toMlirStringRef(value));
1760+
return PyStringAttribute(type.getContext(), attr);
1761+
},
1762+
nb::arg("type"), nb::arg("value"),
1763+
"Gets a uniqued string attribute associated to a type");
1764+
c.def_prop_ro(
1765+
"value",
1766+
[](PyStringAttribute &self) {
1767+
MlirStringRef stringRef = mlirStringAttrGetValue(self);
1768+
return nb::str(stringRef.data, stringRef.length);
1769+
},
1770+
"Returns the value of the string attribute");
1771+
c.def_prop_ro(
1772+
"value_bytes",
1773+
[](PyStringAttribute &self) {
1774+
MlirStringRef stringRef = mlirStringAttrGetValue(self);
1775+
return nb::bytes(stringRef.data, stringRef.length);
1776+
},
1777+
"Returns the value of the string attribute as `bytes`");
1778+
}
1779+
17641780
void mlir::python::populateIRAttributes(nb::module_ &m) {
17651781
PyAffineMapAttribute::bind(m);
17661782
PyDenseBoolArrayAttribute::bind(m);

0 commit comments

Comments
 (0)