@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
485485
486486 PyArrayAttributeIterator &dunderIter () { return *this ; }
487487
488- nb::object dunderNext () {
488+ nb::typed<nb:: object, PyAttribute> dunderNext () {
489489 // TODO: Throw is an inefficient way to stop iteration.
490490 if (nextIndex >= mlirArrayAttrGetNumElements (attr.get ()))
491491 throw nb::stop_iteration ();
@@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
526526 " Gets a uniqued Array attribute" );
527527 c.def (
528528 " __getitem__" ,
529- [](PyArrayAttribute &arr, intptr_t i) {
529+ [](PyArrayAttribute &arr,
530+ intptr_t i) -> nb::typed<nb::object, PyAttribute> {
530531 if (i >= mlirArrayAttrGetNumElements (arr))
531532 throw nb::index_error (" ArrayAttribute index out of range" );
532533 return PyAttribute (arr.getContext (), arr.getItem (i)).maybeDownCast ();
@@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
10101011 [](PyDenseElementsAttribute &self) -> bool {
10111012 return mlirDenseElementsAttrIsSplat (self);
10121013 })
1013- .def (" get_splat_value" , [](PyDenseElementsAttribute &self) {
1014- if (!mlirDenseElementsAttrIsSplat (self))
1015- throw nb::value_error (
1016- " get_splat_value called on a non-splat attribute" );
1017- return PyAttribute (self.getContext (),
1018- mlirDenseElementsAttrGetSplatValue (self))
1019- .maybeDownCast ();
1020- });
1014+ .def (" get_splat_value" ,
1015+ [](PyDenseElementsAttribute &self)
1016+ -> nb::typed<nb::object, PyAttribute> {
1017+ if (!mlirDenseElementsAttrIsSplat (self))
1018+ throw nb::value_error (
1019+ " get_splat_value called on a non-splat attribute" );
1020+ return PyAttribute (self.getContext (),
1021+ mlirDenseElementsAttrGetSplatValue (self))
1022+ .maybeDownCast ();
1023+ });
10211024 }
10221025
10231026 static PyType_Slot slots[];
@@ -1332,7 +1335,7 @@ class PyDenseIntElementsAttribute
13321335
13331336 // / Returns the element at the given linear position. Asserts if the index
13341337 // / is out of range.
1335- nb::object dunderGetItem (intptr_t pos) {
1338+ nb::int_ dunderGetItem (intptr_t pos) {
13361339 if (pos < 0 || pos >= dunderLen ()) {
13371340 throw nb::index_error (" attempt to access out of bounds element" );
13381341 }
@@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
15221525 },
15231526 nb::arg (" value" ) = nb::dict (), nb::arg (" context" ) = nb::none (),
15241527 " Gets an uniqued dict attribute" );
1525- c.def (" __getitem__" , [](PyDictAttribute &self, const std::string &name) {
1526- MlirAttribute attr =
1527- mlirDictionaryAttrGetElementByName (self, toMlirStringRef (name));
1528- if (mlirAttributeIsNull (attr))
1529- throw nb::key_error (" attempt to access a non-existent attribute" );
1530- return PyAttribute (self.getContext (), attr).maybeDownCast ();
1531- });
1528+ c.def (" __getitem__" ,
1529+ [](PyDictAttribute &self,
1530+ const std::string &name) -> nb::typed<nb::object, PyAttribute> {
1531+ MlirAttribute attr =
1532+ mlirDictionaryAttrGetElementByName (self, toMlirStringRef (name));
1533+ if (mlirAttributeIsNull (attr))
1534+ throw nb::key_error (" attempt to access a non-existent attribute" );
1535+ return PyAttribute (self.getContext (), attr).maybeDownCast ();
1536+ });
15321537 c.def (" __getitem__" , [](PyDictAttribute &self, intptr_t index) {
15331538 if (index < 0 || index >= self.dunderLen ()) {
15341539 throw nb::index_error (" attempt to access out of bounds attribute" );
@@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
15941599 },
15951600 nb::arg (" value" ), nb::arg (" context" ) = nb::none (),
15961601 " Gets a uniqued Type attribute" );
1597- c.def_prop_ro (" value" , [](PyTypeAttribute &self) {
1598- return PyType (self.getContext (), mlirTypeAttrGetValue (self.get ()))
1599- .maybeDownCast ();
1600- });
1602+ c.def_prop_ro (
1603+ " value" , [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
1604+ return PyType (self.getContext (), mlirTypeAttrGetValue (self.get ()))
1605+ .maybeDownCast ();
1606+ });
16011607 }
16021608};
16031609
0 commit comments