@@ -167,7 +167,7 @@ struct nb_buffer_info {
167167};
168168
169169class nb_buffer : public nb ::object {
170- NB_OBJECT_DEFAULT (nb_buffer, object, " buffer " , PyObject_CheckBuffer);
170+ NB_OBJECT_DEFAULT (nb_buffer, object, " Buffer " , PyObject_CheckBuffer);
171171
172172 nb_buffer_info request () const {
173173 int flags = PyBUF_STRIDES | PyBUF_FORMAT;
@@ -252,8 +252,13 @@ class PyAffineMapAttribute : public PyConcreteAttribute<PyAffineMapAttribute> {
252252 return PyAffineMapAttribute (affineMap.getContext (), attr);
253253 },
254254 nb::arg (" affine_map" ), " Gets an attribute wrapping an AffineMap." );
255- c.def_prop_ro (" value" , mlirAffineMapAttrGetValue,
256- " Returns the value of the AffineMap attribute" );
255+ c.def_prop_ro (
256+ " value" ,
257+ [](PyAffineMapAttribute &self) {
258+ return PyAffineMap (self.getContext (),
259+ mlirAffineMapAttrGetValue (self));
260+ },
261+ " Returns the value of the AffineMap attribute" );
257262 }
258263};
259264
@@ -480,11 +485,13 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
480485
481486 PyArrayAttributeIterator &dunderIter () { return *this ; }
482487
483- MlirAttribute dunderNext () {
488+ nb::typed<nb::object, PyAttribute> dunderNext () {
484489 // TODO: Throw is an inefficient way to stop iteration.
485490 if (nextIndex >= mlirArrayAttrGetNumElements (attr.get ()))
486491 throw nb::stop_iteration ();
487- return mlirArrayAttrGetElement (attr.get (), nextIndex++);
492+ return PyAttribute (this ->attr .getContext (),
493+ mlirArrayAttrGetElement (attr.get (), nextIndex++))
494+ .maybeDownCast ();
488495 }
489496
490497 static void bind (nb::module_ &m) {
@@ -517,12 +524,13 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
517524 },
518525 nb::arg (" attributes" ), nb::arg (" context" ) = nb::none (),
519526 " Gets a uniqued Array attribute" );
520- c.def (" __getitem__" ,
521- [](PyArrayAttribute &arr, intptr_t i) {
522- if (i >= mlirArrayAttrGetNumElements (arr))
523- throw nb::index_error (" ArrayAttribute index out of range" );
524- return arr.getItem (i);
525- })
527+ c.def (
528+ " __getitem__" ,
529+ [](PyArrayAttribute &arr, intptr_t i) {
530+ if (i >= mlirArrayAttrGetNumElements (arr))
531+ throw nb::index_error (" ArrayAttribute index out of range" );
532+ return PyAttribute (arr.getContext (), arr.getItem (i)).maybeDownCast ();
533+ })
526534 .def (" __len__" ,
527535 [](const PyArrayAttribute &arr) {
528536 return mlirArrayAttrGetNumElements (arr);
@@ -611,10 +619,12 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
611619 " Returns the value of the integer attribute" );
612620 c.def (" __int__" , toPyInt,
613621 " Converts the value of the integer attribute to a Python int" );
614- c.def_prop_ro_static (" static_typeid" ,
615- [](nb::object & /* class*/ ) -> MlirTypeID {
616- return mlirIntegerAttrGetTypeID ();
617- });
622+ c.def_prop_ro_static (
623+ " static_typeid" ,
624+ [](nb::object & /* class*/ ) {
625+ return PyTypeID (mlirIntegerAttrGetTypeID ());
626+ },
627+ nanobind::sig (" def static_typeid(/) -> TypeID" ));
618628 }
619629
620630private:
@@ -657,8 +667,8 @@ class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
657667 static constexpr const char *pyClassName = " SymbolRefAttr" ;
658668 using PyConcreteAttribute::PyConcreteAttribute;
659669
660- static MlirAttribute fromList (const std::vector<std::string> &symbols,
661- PyMlirContext &context) {
670+ static PySymbolRefAttribute fromList (const std::vector<std::string> &symbols,
671+ PyMlirContext &context) {
662672 if (symbols.empty ())
663673 throw std::runtime_error (" SymbolRefAttr must be composed of at least "
664674 " one symbol." );
@@ -668,8 +678,10 @@ class PySymbolRefAttribute : public PyConcreteAttribute<PySymbolRefAttribute> {
668678 referenceAttrs.push_back (
669679 mlirFlatSymbolRefAttrGet (context.get (), toMlirStringRef (symbols[i])));
670680 }
671- return mlirSymbolRefAttrGet (context.get (), rootSymbol,
672- referenceAttrs.size (), referenceAttrs.data ());
681+ return PySymbolRefAttribute (context.getRef (),
682+ mlirSymbolRefAttrGet (context.get (), rootSymbol,
683+ referenceAttrs.size (),
684+ referenceAttrs.data ()));
673685 }
674686
675687 static void bindDerived (ClassTy &c) {
@@ -746,7 +758,11 @@ class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
746758 return PyOpaqueAttribute (context->getRef (), attr);
747759 },
748760 nb::arg (" dialect_namespace" ), nb::arg (" buffer" ), nb::arg (" type" ),
749- nb::arg (" context" ) = nb::none (), " Gets an Opaque attribute." );
761+ nb::arg (" context" ) = nb::none (),
762+ // clang-format off
763+ nb::sig (" def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr" ),
764+ // clang-format on
765+ " Gets an Opaque attribute." );
750766 c.def_prop_ro (
751767 " dialect_namespace" ,
752768 [](PyOpaqueAttribute &self) {
@@ -764,59 +780,6 @@ class PyOpaqueAttribute : public PyConcreteAttribute<PyOpaqueAttribute> {
764780 }
765781};
766782
767- class PyStringAttribute : public PyConcreteAttribute <PyStringAttribute> {
768- public:
769- static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
770- static constexpr const char *pyClassName = " StringAttr" ;
771- using PyConcreteAttribute::PyConcreteAttribute;
772- static constexpr GetTypeIDFunctionTy getTypeIdFunction =
773- mlirStringAttrGetTypeID;
774-
775- static void bindDerived (ClassTy &c) {
776- c.def_static (
777- " get" ,
778- [](const std::string &value, DefaultingPyMlirContext context) {
779- MlirAttribute attr =
780- mlirStringAttrGet (context->get (), toMlirStringRef (value));
781- return PyStringAttribute (context->getRef (), attr);
782- },
783- nb::arg (" value" ), nb::arg (" context" ) = nb::none (),
784- " Gets a uniqued string attribute" );
785- c.def_static (
786- " get" ,
787- [](const nb::bytes &value, DefaultingPyMlirContext context) {
788- MlirAttribute attr =
789- mlirStringAttrGet (context->get (), toMlirStringRef (value));
790- return PyStringAttribute (context->getRef (), attr);
791- },
792- nb::arg (" value" ), nb::arg (" context" ) = nb::none (),
793- " Gets a uniqued string attribute" );
794- c.def_static (
795- " get_typed" ,
796- [](PyType &type, const std::string &value) {
797- MlirAttribute attr =
798- mlirStringAttrTypedGet (type, toMlirStringRef (value));
799- return PyStringAttribute (type.getContext (), attr);
800- },
801- nb::arg (" type" ), nb::arg (" value" ),
802- " Gets a uniqued string attribute associated to a type" );
803- c.def_prop_ro (
804- " value" ,
805- [](PyStringAttribute &self) {
806- MlirStringRef stringRef = mlirStringAttrGetValue (self);
807- return nb::str (stringRef.data , stringRef.length );
808- },
809- " Returns the value of the string attribute" );
810- c.def_prop_ro (
811- " value_bytes" ,
812- [](PyStringAttribute &self) {
813- MlirStringRef stringRef = mlirStringAttrGetValue (self);
814- return nb::bytes (stringRef.data , stringRef.length );
815- },
816- " Returns the value of the string attribute as `bytes`" );
817- }
818- };
819-
820783// TODO: Support construction of string elements.
821784class PyDenseElementsAttribute
822785 : public PyConcreteAttribute<PyDenseElementsAttribute> {
@@ -1028,11 +991,14 @@ class PyDenseElementsAttribute
1028991 PyDenseElementsAttribute::bf_releasebuffer;
1029992#endif
1030993 c.def (" __len__" , &PyDenseElementsAttribute::dunderLen)
1031- .def_static (" get" , PyDenseElementsAttribute::getFromBuffer,
1032- nb::arg (" array" ), nb::arg (" signless" ) = true ,
1033- nb::arg (" type" ) = nb::none (), nb::arg (" shape" ) = nb::none (),
1034- nb::arg (" context" ) = nb::none (),
1035- kDenseElementsAttrGetDocstring )
994+ .def_static (
995+ " get" , PyDenseElementsAttribute::getFromBuffer, nb::arg (" array" ),
996+ nb::arg (" signless" ) = true , nb::arg (" type" ) = nb::none (),
997+ nb::arg (" shape" ) = nb::none (), nb::arg (" context" ) = nb::none (),
998+ // clang-format off
999+ nb::sig (" def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr" ),
1000+ // clang-format on
1001+ kDenseElementsAttrGetDocstring )
10361002 .def_static (" get" , PyDenseElementsAttribute::getFromList,
10371003 nb::arg (" attrs" ), nb::arg (" type" ) = nb::none (),
10381004 nb::arg (" context" ) = nb::none (),
@@ -1048,7 +1014,9 @@ class PyDenseElementsAttribute
10481014 if (!mlirDenseElementsAttrIsSplat (self))
10491015 throw nb::value_error (
10501016 " get_splat_value called on a non-splat attribute" );
1051- return mlirDenseElementsAttrGetSplatValue (self);
1017+ return PyAttribute (self.getContext (),
1018+ mlirDenseElementsAttrGetSplatValue (self))
1019+ .maybeDownCast ();
10521020 });
10531021 }
10541022
@@ -1509,6 +1477,9 @@ class PyDenseResourceElementsAttribute
15091477 nb::arg (" array" ), nb::arg (" name" ), nb::arg (" type" ),
15101478 nb::arg (" alignment" ) = nb::none (),
15111479 nb::arg (" is_mutable" ) = false , nb::arg (" context" ) = nb::none (),
1480+ // clang-format off
1481+ nb::sig (" def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr" ),
1482+ // clang-format on
15121483 kDenseResourceElementsAttrGetFromBufferDocstring );
15131484 }
15141485};
@@ -1556,7 +1527,7 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
15561527 mlirDictionaryAttrGetElementByName (self, toMlirStringRef (name));
15571528 if (mlirAttributeIsNull (attr))
15581529 throw nb::key_error (" attempt to access a non-existent attribute" );
1559- return attr;
1530+ return PyAttribute (self. getContext (), attr). maybeDownCast () ;
15601531 });
15611532 c.def (" __getitem__" , [](PyDictAttribute &self, intptr_t index) {
15621533 if (index < 0 || index >= self.dunderLen ()) {
@@ -1624,7 +1595,8 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
16241595 nb::arg (" value" ), nb::arg (" context" ) = nb::none (),
16251596 " Gets a uniqued Type attribute" );
16261597 c.def_prop_ro (" value" , [](PyTypeAttribute &self) {
1627- return mlirTypeAttrGetValue (self.get ());
1598+ return PyType (self.getContext (), mlirTypeAttrGetValue (self.get ()))
1599+ .maybeDownCast ();
16281600 });
16291601 }
16301602};
@@ -1761,6 +1733,50 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) {
17611733
17621734} // namespace
17631735
1736+ void PyStringAttribute::bindDerived (ClassTy &c) {
1737+ c.def_static (
1738+ " get" ,
1739+ [](const std::string &value, DefaultingPyMlirContext context) {
1740+ MlirAttribute attr =
1741+ mlirStringAttrGet (context->get (), toMlirStringRef (value));
1742+ return PyStringAttribute (context->getRef (), attr);
1743+ },
1744+ nb::arg (" value" ), nb::arg (" context" ) = nb::none (),
1745+ " Gets a uniqued string attribute" );
1746+ c.def_static (
1747+ " get" ,
1748+ [](const nb::bytes &value, DefaultingPyMlirContext context) {
1749+ MlirAttribute attr =
1750+ mlirStringAttrGet (context->get (), toMlirStringRef (value));
1751+ return PyStringAttribute (context->getRef (), attr);
1752+ },
1753+ nb::arg (" value" ), nb::arg (" context" ) = nb::none (),
1754+ " Gets a uniqued string attribute" );
1755+ c.def_static (
1756+ " get_typed" ,
1757+ [](PyType &type, const std::string &value) {
1758+ MlirAttribute attr =
1759+ mlirStringAttrTypedGet (type, toMlirStringRef (value));
1760+ return PyStringAttribute (type.getContext (), attr);
1761+ },
1762+ nb::arg (" type" ), nb::arg (" value" ),
1763+ " Gets a uniqued string attribute associated to a type" );
1764+ c.def_prop_ro (
1765+ " value" ,
1766+ [](PyStringAttribute &self) {
1767+ MlirStringRef stringRef = mlirStringAttrGetValue (self);
1768+ return nb::str (stringRef.data , stringRef.length );
1769+ },
1770+ " Returns the value of the string attribute" );
1771+ c.def_prop_ro (
1772+ " value_bytes" ,
1773+ [](PyStringAttribute &self) {
1774+ MlirStringRef stringRef = mlirStringAttrGetValue (self);
1775+ return nb::bytes (stringRef.data , stringRef.length );
1776+ },
1777+ " Returns the value of the string attribute as `bytes`" );
1778+ }
1779+
17641780void mlir::python::populateIRAttributes (nb::module_ &m) {
17651781 PyAffineMapAttribute::bind (m);
17661782 PyDenseBoolArrayAttribute::bind (m);
0 commit comments