Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 27 additions & 22 deletions mlir/lib/Bindings/Python/IRAffine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
})
.def_prop_ro(
"context",
[](PyAffineExpr &self) { return self.getContext().getObject(); })
[](PyAffineExpr &self) -> nb::typed<nb::object, PyMlirContext> {
return self.getContext().getObject();
})
.def("compose",
[](PyAffineExpr &self, PyAffineMap &other) {
return PyAffineExpr(self.getContext(),
Expand Down Expand Up @@ -706,28 +708,29 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
[](PyAffineMap &self) {
return static_cast<size_t>(llvm::hash_value(self.get().ptr));
})
.def_static("compress_unused_symbols",
[](const nb::list &affineMaps,
DefaultingPyMlirContext context) {
SmallVector<MlirAffineMap> maps;
pyListToVector<PyAffineMap, MlirAffineMap>(
affineMaps, maps, "attempting to create an AffineMap");
std::vector<MlirAffineMap> compressed(affineMaps.size());
auto populate = [](void *result, intptr_t idx,
MlirAffineMap m) {
static_cast<MlirAffineMap *>(result)[idx] = (m);
};
mlirAffineMapCompressUnusedSymbols(
maps.data(), maps.size(), compressed.data(), populate);
std::vector<PyAffineMap> res;
res.reserve(compressed.size());
for (auto m : compressed)
res.emplace_back(context->getRef(), m);
return res;
})
.def_static(
"compress_unused_symbols",
[](const nb::list &affineMaps, DefaultingPyMlirContext context) {
SmallVector<MlirAffineMap> maps;
pyListToVector<PyAffineMap, MlirAffineMap>(
affineMaps, maps, "attempting to create an AffineMap");
std::vector<MlirAffineMap> compressed(affineMaps.size());
auto populate = [](void *result, intptr_t idx, MlirAffineMap m) {
static_cast<MlirAffineMap *>(result)[idx] = (m);
};
mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(),
compressed.data(), populate);
std::vector<PyAffineMap> res;
res.reserve(compressed.size());
for (auto m : compressed)
res.emplace_back(context->getRef(), m);
return res;
})
.def_prop_ro(
"context",
[](PyAffineMap &self) { return self.getContext().getObject(); },
[](PyAffineMap &self) -> nb::typed<nb::object, PyMlirContext> {
return self.getContext().getObject();
},
"Context that owns the Affine Map")
.def(
"dump", [](PyAffineMap &self) { mlirAffineMapDump(self); },
Expand Down Expand Up @@ -893,7 +896,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) {
})
.def_prop_ro(
"context",
[](PyIntegerSet &self) { return self.getContext().getObject(); })
[](PyIntegerSet &self) -> nb::typed<nb::object, PyMlirContext> {
return self.getContext().getObject();
})
.def(
"dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); },
kDumpDocstring)
Expand Down
50 changes: 28 additions & 22 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {

PyArrayAttributeIterator &dunderIter() { return *this; }

nb::object dunderNext() {
nb::typed<nb::object, PyAttribute> dunderNext() {
// TODO: Throw is an inefficient way to stop iteration.
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get()))
throw nb::stop_iteration();
Expand Down Expand Up @@ -526,7 +526,8 @@ class PyArrayAttribute : public PyConcreteAttribute<PyArrayAttribute> {
"Gets a uniqued Array attribute");
c.def(
"__getitem__",
[](PyArrayAttribute &arr, intptr_t i) {
[](PyArrayAttribute &arr,
intptr_t i) -> nb::typed<nb::object, PyAttribute> {
if (i >= mlirArrayAttrGetNumElements(arr))
throw nb::index_error("ArrayAttribute index out of range");
return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast();
Expand Down Expand Up @@ -1010,14 +1011,16 @@ class PyDenseElementsAttribute
[](PyDenseElementsAttribute &self) -> bool {
return mlirDenseElementsAttrIsSplat(self);
})
.def("get_splat_value", [](PyDenseElementsAttribute &self) {
if (!mlirDenseElementsAttrIsSplat(self))
throw nb::value_error(
"get_splat_value called on a non-splat attribute");
return PyAttribute(self.getContext(),
mlirDenseElementsAttrGetSplatValue(self))
.maybeDownCast();
});
.def("get_splat_value",
[](PyDenseElementsAttribute &self)
-> nb::typed<nb::object, PyAttribute> {
if (!mlirDenseElementsAttrIsSplat(self))
throw nb::value_error(
"get_splat_value called on a non-splat attribute");
return PyAttribute(self.getContext(),
mlirDenseElementsAttrGetSplatValue(self))
.maybeDownCast();
});
}

static PyType_Slot slots[];
Expand Down Expand Up @@ -1332,7 +1335,7 @@ class PyDenseIntElementsAttribute

/// Returns the element at the given linear position. Asserts if the index
/// is out of range.
nb::object dunderGetItem(intptr_t pos) {
nb::int_ dunderGetItem(intptr_t pos) {
if (pos < 0 || pos >= dunderLen()) {
throw nb::index_error("attempt to access out of bounds element");
}
Expand Down Expand Up @@ -1522,13 +1525,15 @@ class PyDictAttribute : public PyConcreteAttribute<PyDictAttribute> {
},
nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(),
"Gets an uniqued dict attribute");
c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) {
MlirAttribute attr =
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
if (mlirAttributeIsNull(attr))
throw nb::key_error("attempt to access a non-existent attribute");
return PyAttribute(self.getContext(), attr).maybeDownCast();
});
c.def("__getitem__",
[](PyDictAttribute &self,
const std::string &name) -> nb::typed<nb::object, PyAttribute> {
MlirAttribute attr =
mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name));
if (mlirAttributeIsNull(attr))
throw nb::key_error("attempt to access a non-existent attribute");
return PyAttribute(self.getContext(), attr).maybeDownCast();
});
c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) {
if (index < 0 || index >= self.dunderLen()) {
throw nb::index_error("attempt to access out of bounds attribute");
Expand Down Expand Up @@ -1594,10 +1599,11 @@ class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
},
nb::arg("value"), nb::arg("context") = nb::none(),
"Gets a uniqued Type attribute");
c.def_prop_ro("value", [](PyTypeAttribute &self) {
return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
.maybeDownCast();
});
c.def_prop_ro(
"value", [](PyTypeAttribute &self) -> nb::typed<nb::object, PyType> {
return PyType(self.getContext(), mlirTypeAttrGetValue(self.get()))
.maybeDownCast();
});
}
};

Expand Down
Loading