Skip to content

Commit 3376f4f

Browse files
committed
[MLIR][Python] use nb::typed for return signatures
1 parent 18f7e03 commit 3376f4f

File tree

4 files changed

+69
-50
lines changed

4 files changed

+69
-50
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
485485

486486
PyArrayAttributeIterator &dunderIter() { return *this; }
487487

488-
nb::object dunderNext() {
488+
nb::typed<nb::object, PyAttribute> dunderNext() {
489489
// TODO: Throw is an inefficient way to stop iteration.
490490
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
491491
throw nb::stop_iteration();
@@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
526526
"Gets a uniqued Array attribute");
527527
c.def(
528528
"__getitem__",
529-
[](PyArrayAttribute &arr, intptr_t i) {
529+
[](PyArrayAttribute &arr,
530+
intptr_t i) -> nb::typed<nb::object, PyAttribute> {
530531
if (i >= mlirArrayAttrGetNumElements(arr))
531532
throw nb::index_error("ArrayAttribute index out of range");
532533
return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
@@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
10101011
[](PyDenseElementsAttribute &self) -> bool {
10111012
return mlirDenseElementsAttrIsSplat(self);
10121013
})
1013-
.def("get_splat_value", [](PyDenseElementsAttribute &self) {
1014-
if (!mlirDenseElementsAttrIsSplat(self))
1015-
throw nb::value_error(
1016-
"get_splat_value called on a non-splat attribute");
1017-
return PyAttribute(self.getContext(),
1018-
mlirDenseElementsAttrGetSplatValue(self))
1019-
.maybeDownCast();
1020-
});
1014+
.def("get_splat_value",
1015+
[](PyDenseElementsAttribute &self)
1016+
-> nb::typed<nb::object, PyAttribute> {
1017+
if (!mlirDenseElementsAttrIsSplat(self))
1018+
throw nb::value_error(
1019+
"get_splat_value called on a non-splat attribute");
1020+
return PyAttribute(self.getContext(),
1021+
mlirDenseElementsAttrGetSplatValue(self))
1022+
.maybeDownCast();
1023+
});
10211024
}
10221025

10231026
static PyType_Slot slots[];
@@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
15221525
},
15231526
nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
15241527
"Gets an uniqued dict attribute");
1525-
c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
1526-
MlirAttribute attr =
1527-
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1528-
if (mlirAttributeIsNull(attr))
1529-
throw nb::key_error("attempt to access a non-existent attribute");
1530-
return PyAttribute(self.getContext(), attr).maybeDownCast();
1531-
});
1528+
c.def("__getitem__",
1529+
[](PyDictAttribute &self,
1530+
const std::string &name) -> nb::typed<nb::object, PyAttribute> {
1531+
MlirAttribute attr =
1532+
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
1533+
if (mlirAttributeIsNull(attr))
1534+
throw nb::key_error("attempt to access a non-existent attribute");
1535+
return PyAttribute(self.getContext(), attr).maybeDownCast();
1536+
});
15321537
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
15331538
if (index < 0 || index >= self.dunderLen()) {
15341539
throw nb::index_error("attempt to access out of bounds attribute");
@@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
15941599
},
15951600
nb::arg("value"), nb::arg("context") = nb::none(),
15961601
"Gets a uniqued Type attribute");
1597-
c.def_prop_ro("value", [](PyTypeAttribute &self) {
1598-
return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
1599-
.maybeDownCast();
1600-
});
1602+
c.def_prop_ro(
1603+
"value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
1604+
return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
1605+
.maybeDownCast();
1606+
});
16011607
}
16021608
};
16031609

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,7 +1605,9 @@ class PyConcreteValue : public PyValue {
16051605
},
16061606
nb::arg("other_value"));
16071607
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
1608-
[](DerivedTy &self) { return self.maybeDownCast(); });
1608+
[](DerivedTy &self) -> nb::typed<nb::object, DerivedTy> {
1609+
return self.maybeDownCast();
1610+
});
16091611
DerivedTy::bindDerived(cls);
16101612
}
16111613

@@ -1638,9 +1640,9 @@ class PyOpResult : public PyConcreteValue<PyOpResult> {
16381640

16391641
/// Returns the list of types of the values held by container.
16401642
template <typename Container>
1641-
static std::vector<nb::object> getValueTypes(Container &container,
1642-
PyMlirContextRef &context) {
1643-
std::vector<nb::object> result;
1643+
static std::vector<nb::typed<nb::object, PyType>>
1644+
getValueTypes(Container &container, PyMlirContextRef &context) {
1645+
std::vector<nb::typed<nb::object, PyType>> result;
16441646
result.reserve(container.size());
16451647
for (int i = 0, e = container.size(); i < e; ++i) {
16461648
result.push_back(PyType(context->getRef(),
@@ -2677,7 +2679,8 @@ class PyOpAttributeMap {
26772679
PyOpAttributeMap(PyOperationRef operation)
26782680
: operation(std::move(operation)) {}
26792681

2680-
nb::object dunderGetItemNamed(const std::string &name) {
2682+
nb::typed<nb::object, PyAttribute>
2683+
dunderGetItemNamed(const std::string &name) {
26812684
MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
26822685
toMlirStringRef(name));
26832686
if (mlirAttributeIsNull(attr)) {
@@ -3461,7 +3464,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
34613464
"Returns the list of Operation results.")
34623465
.def_prop_ro(
34633466
"result",
3464-
[](PyOperationBase &self) {
3467+
[](PyOperationBase &self) -> nb::typed<nb::object, PyOpResult> {
34653468
auto &operation = self.getOperation();
34663469
return PyOpResult(operation.getRef(), getUniqueResult(operation))
34673470
.maybeDownCast();
@@ -3982,7 +3985,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
39823985
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
39833986
.def_static(
39843987
"parse",
3985-
[](const std::string &attrSpec, DefaultingPyMlirContext context) {
3988+
[](const std::string &attrSpec, DefaultingPyMlirContext context)
3989+
-> nb::typed<nb::object, PyAttribute> {
39863990
PyMlirContext::ErrorCapture errors(context->getRef());
39873991
MlirAttribute attr = mlirAttributeParseGet(
39883992
context->get(), toMlirStringRef(attrSpec));
@@ -3998,7 +4002,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
39984002
[](PyAttribute &self) { return self.getContext().getObject(); },
39994003
"Context that owns the Attribute")
40004004
.def_prop_ro("type",
4001-
[](PyAttribute &self) {
4005+
[](PyAttribute &self) -> nb::typed<nb::object, PyType> {
40024006
return PyType(self.getContext(),
40034007
mlirAttributeGetType(self))
40044008
.maybeDownCast();
@@ -4049,7 +4053,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
40494053
"mlirTypeID was expected to be non-null.");
40504054
return PyTypeID(mlirTypeID);
40514055
})
4052-
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyAttribute::maybeDownCast);
4056+
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
4057+
[](PyAttribute &self) -> nb::typed<nb::object, PyAttribute> {
4058+
return self.maybeDownCast();
4059+
});
40534060

40544061
//----------------------------------------------------------------------------
40554062
// Mapping of PyNamedAttribute
@@ -4094,7 +4101,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
40944101
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
40954102
.def_static(
40964103
"parse",
4097-
[](std::string typeSpec, DefaultingPyMlirContext context) {
4104+
[](std::string typeSpec,
4105+
DefaultingPyMlirContext context) -> nb::typed<nb::object, PyType> {
40984106
PyMlirContext::ErrorCapture errors(context->getRef());
40994107
MlirType type =
41004108
mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
@@ -4139,7 +4147,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
41394147
printAccum.parts.append(")");
41404148
return printAccum.join();
41414149
})
4142-
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyType::maybeDownCast)
4150+
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
4151+
[](PyType &self) -> nb::typed<nb::object, PyType> {
4152+
return self.maybeDownCast();
4153+
})
41434154
.def_prop_ro("typeid", [](PyType &self) {
41444155
MlirTypeID mlirTypeID = mlirTypeGetTypeID(self);
41454156
if (!mlirTypeIDIsNull(mlirTypeID))
@@ -4266,7 +4277,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
42664277
},
42674278
nb::arg("state"), kGetNameAsOperand)
42684279
.def_prop_ro("type",
4269-
[](PyValue &self) {
4280+
[](PyValue &self) -> nb::typed<nb::object, PyType> {
42704281
return PyType(self.getParentOperation()->getContext(),
42714282
mlirValueGetType(self.get()))
42724283
.maybeDownCast();
@@ -4332,7 +4343,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
43324343
},
43334344
nb::arg("with_"), nb::arg("exceptions"),
43344345
kValueReplaceAllUsesExceptDocstring)
4335-
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, &PyValue::maybeDownCast)
4346+
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
4347+
[](PyValue &self) -> nb::typed<nb::object, PyValue> {
4348+
return self.maybeDownCast();
4349+
})
43364350
.def_prop_ro(
43374351
"location",
43384352
[](MlirValue self) {

mlir/lib/Bindings/Python/IRModule.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,10 +1101,12 @@ class PyConcreteAttribute : public BaseTy {
11011101
return DerivedTy::isaFunction(otherAttr);
11021102
},
11031103
nanobind::arg("other"));
1104-
cls.def_prop_ro("type", [](PyAttribute &attr) {
1105-
return PyType(attr.getContext(), mlirAttributeGetType(attr))
1106-
.maybeDownCast();
1107-
});
1104+
cls.def_prop_ro(
1105+
"type",
1106+
[](PyAttribute &attr) -> nanobind::typed<nanobind::object, PyType> {
1107+
return PyType(attr.getContext(), mlirAttributeGetType(attr))
1108+
.maybeDownCast();
1109+
});
11081110
cls.def_prop_ro_static(
11091111
"static_typeid",
11101112
[](nanobind::object & /*class*/) -> PyTypeID {

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
501501
"Create a complex type");
502502
c.def_prop_ro(
503503
"element_type",
504-
[](PyComplexType &self) {
504+
[](PyComplexType &self) -> nb::typed<nb::object, PyType> {
505505
return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
506506
.maybeDownCast();
507507
},
@@ -515,7 +515,7 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
515515
void mlir::PyShapedType::bindDerived(ClassTy &c) {
516516
c.def_prop_ro(
517517
"element_type",
518-
[](PyShapedType &self) {
518+
[](PyShapedType &self) -> nb::typed<nb::object, PyType> {
519519
return PyType(self.getContext(), mlirShapedTypeGetElementType(self))
520520
.maybeDownCast();
521521
},
@@ -731,8 +731,7 @@ class PyRankedTensorType
731731
MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
732732
if (mlirAttributeIsNull(encoding))
733733
return std::nullopt;
734-
return nb::cast<nb::typed<nb::object, PyAttribute>>(
735-
PyAttribute(self.getContext(), encoding).maybeDownCast());
734+
return PyAttribute(self.getContext(), encoding).maybeDownCast();
736735
});
737736
}
738737
};
@@ -794,9 +793,9 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
794793
.def_prop_ro(
795794
"layout",
796795
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
797-
return nb::cast<nb::typed<nb::object, PyAttribute>>(
798-
PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
799-
.maybeDownCast());
796+
return PyAttribute(self.getContext(),
797+
mlirMemRefTypeGetLayout(self))
798+
.maybeDownCast();
800799
},
801800
"The layout of the MemRef type.")
802801
.def(
@@ -825,8 +824,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
825824
MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
826825
if (mlirAttributeIsNull(a))
827826
return std::nullopt;
828-
return nb::cast<nb::typed<nb::object, PyAttribute>>(
829-
PyAttribute(self.getContext(), a).maybeDownCast());
827+
return PyAttribute(self.getContext(), a).maybeDownCast();
830828
},
831829
"Returns the memory space of the given MemRef type.");
832830
}
@@ -867,8 +865,7 @@ class PyUnrankedMemRefType
867865
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
868866
if (mlirAttributeIsNull(a))
869867
return std::nullopt;
870-
return nb::cast<nb::typed<nb::object, PyAttribute>>(
871-
PyAttribute(self.getContext(), a).maybeDownCast());
868+
return PyAttribute(self.getContext(), a).maybeDownCast();
872869
},
873870
"Returns the memory space of the given Unranked MemRef type.");
874871
}
@@ -912,7 +909,7 @@ class PyTupleType : public PyConcreteType<PyTupleType> {
912909
"Create a tuple type");
913910
c.def(
914911
"get_type",
915-
[](PyTupleType &self, intptr_t pos) {
912+
[](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
916913
return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
917914
.maybeDownCast();
918915
},

0 commit comments

Comments
 (0)