Skip to content

Commit 3fb3522

Browse files
committed
[MLIR][Python] rename checked gettors and add unchecked gettors
1 parent d2f14bc commit 3fb3522

File tree

3 files changed

+169
-25
lines changed

3 files changed

+169
-25
lines changed

mlir/lib/Bindings/Python/DialectLLVM.cpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
3333
auto llvmStructType =
3434
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
3535

36-
llvmStructType.def_classmethod(
37-
"get_literal",
38-
[](const nb::object &cls, const std::vector<MlirType> &elements,
39-
bool packed, MlirLocation loc) {
40-
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
41-
42-
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
43-
loc, elements.size(), elements.data(), packed);
44-
if (mlirTypeIsNull(type)) {
45-
throw nb::value_error(scope.takeMessage().c_str());
46-
}
47-
return cls(type);
48-
},
49-
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
50-
"loc"_a = nb::none());
36+
llvmStructType
37+
.def_classmethod(
38+
"get_literal_checked",
39+
[](const nb::object &cls, const std::vector<MlirType> &elements,
40+
bool packed, MlirLocation loc) {
41+
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
42+
43+
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
44+
loc, elements.size(), elements.data(), packed);
45+
if (mlirTypeIsNull(type)) {
46+
throw nb::value_error(scope.takeMessage().c_str());
47+
}
48+
return cls(type);
49+
},
50+
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
51+
"loc"_a = nb::none())
52+
.def_classmethod(
53+
"get_literal",
54+
[](const nb::object &cls, const std::vector<MlirType> &elements,
55+
bool packed, MlirContext context) {
56+
CollectDiagnosticsToStringScope scope(context);
57+
58+
MlirType type = mlirLLVMStructTypeLiteralGet(
59+
context, elements.size(), elements.data(), packed);
60+
if (mlirTypeIsNull(type)) {
61+
throw nb::value_error(scope.takeMessage().c_str());
62+
}
63+
return cls(type);
64+
},
65+
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
66+
"context"_a = nb::none());
5167

5268
llvmStructType.def_classmethod(
5369
"get_identified",

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
565565

566566
static void bindDerived(ClassTy &c) {
567567
c.def_static(
568-
"get",
568+
"get_checked",
569569
[](PyType &type, double value, DefaultingPyLocation loc) {
570570
PyMlirContext::ErrorCapture errors(loc->getContext());
571571
MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value);
@@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
575575
},
576576
nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
577577
"Gets an uniqued float point attribute associated to a type");
578+
c.def_static(
579+
"get",
580+
[](PyType &type, double value, DefaultingPyMlirContext context) {
581+
PyMlirContext::ErrorCapture errors(context->getRef());
582+
MlirAttribute attr =
583+
mlirFloatAttrDoubleGet(context.get()->get(), type, value);
584+
if (mlirAttributeIsNull(attr))
585+
throw MLIRError("Invalid attribute", errors.take());
586+
return PyFloatAttribute(type.getContext(), attr);
587+
},
588+
nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
589+
"Gets an uniqued float point attribute associated to a type");
578590
c.def_static(
579591
"get_f32",
580592
[](double value, DefaultingPyMlirContext context) {

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 125 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,12 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
643643
nb::arg("element_type"), nb::kw_only(),
644644
nb::arg("scalable") = nb::none(),
645645
nb::arg("scalable_dims") = nb::none(),
646-
nb::arg("loc") = nb::none(), "Create a vector type")
646+
nb::arg("context") = nb::none(), "Create a vector type")
647+
.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
648+
nb::arg("element_type"), nb::kw_only(),
649+
nb::arg("scalable") = nb::none(),
650+
nb::arg("scalable_dims") = nb::none(),
651+
nb::arg("loc") = nb::none(), "Create a vector type")
647652
.def_prop_ro(
648653
"scalable",
649654
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
658663
}
659664

660665
private:
661-
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
662-
std::optional<nb::list> scalable,
663-
std::optional<std::vector<int64_t>> scalableDims,
664-
DefaultingPyLocation loc) {
666+
static PyVectorType
667+
getChecked(std::vector<int64_t> shape, PyType &elementType,
668+
std::optional<nb::list> scalable,
669+
std::optional<std::vector<int64_t>> scalableDims,
670+
DefaultingPyLocation loc) {
665671
if (scalable && scalableDims) {
666672
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
667673
"are mutually exclusive.");
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
696702
throw MLIRError("Invalid type", errors.take());
697703
return PyVectorType(elementType.getContext(), type);
698704
}
705+
706+
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
707+
std::optional<nb::list> scalable,
708+
std::optional<std::vector<int64_t>> scalableDims,
709+
DefaultingPyMlirContext context) {
710+
if (scalable && scalableDims) {
711+
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
712+
"are mutually exclusive.");
713+
}
714+
715+
PyMlirContext::ErrorCapture errors(context->getRef());
716+
MlirType type;
717+
if (scalable) {
718+
if (scalable->size() != shape.size())
719+
throw nb::value_error("Expected len(scalable) == len(shape).");
720+
721+
SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
722+
*scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
723+
type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
724+
scalableDimFlags.data(), elementType);
725+
} else if (scalableDims) {
726+
SmallVector<bool> scalableDimFlags(shape.size(), false);
727+
for (int64_t dim : *scalableDims) {
728+
if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
729+
throw nb::value_error("Scalable dimension index out of bounds.");
730+
scalableDimFlags[dim] = true;
731+
}
732+
type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
733+
scalableDimFlags.data(), elementType);
734+
} else {
735+
type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
736+
}
737+
if (mlirTypeIsNull(type))
738+
throw MLIRError("Invalid type", errors.take());
739+
return PyVectorType(elementType.getContext(), type);
740+
}
699741
};
700742

701743
/// Ranked Tensor Type subclass - RankedTensorType.
@@ -710,7 +752,7 @@ class PyRankedTensorType
710752

711753
static void bindDerived(ClassTy &c) {
712754
c.def_static(
713-
"get",
755+
"get_checked",
714756
[](std::vector<int64_t> shape, PyType &elementType,
715757
std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
716758
PyMlirContext::ErrorCapture errors(loc->getContext());
@@ -724,6 +766,22 @@ class PyRankedTensorType
724766
nb::arg("shape"), nb::arg("element_type"),
725767
nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
726768
"Create a ranked tensor type");
769+
c.def_static(
770+
"get",
771+
[](std::vector<int64_t> shape, PyType &elementType,
772+
std::optional<PyAttribute> &encodingAttr,
773+
DefaultingPyMlirContext context) {
774+
PyMlirContext::ErrorCapture errors(context->getRef());
775+
MlirType t = mlirRankedTensorTypeGet(
776+
shape.size(), shape.data(), elementType,
777+
encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
778+
if (mlirTypeIsNull(t))
779+
throw MLIRError("Invalid type", errors.take());
780+
return PyRankedTensorType(elementType.getContext(), t);
781+
},
782+
nb::arg("shape"), nb::arg("element_type"),
783+
nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
784+
"Create a ranked tensor type");
727785
c.def_prop_ro(
728786
"encoding",
729787
[](PyRankedTensorType &self)
@@ -748,7 +806,7 @@ class PyUnrankedTensorType
748806

749807
static void bindDerived(ClassTy &c) {
750808
c.def_static(
751-
"get",
809+
"get_checked",
752810
[](PyType &elementType, DefaultingPyLocation loc) {
753811
PyMlirContext::ErrorCapture errors(loc->getContext());
754812
MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
758816
},
759817
nb::arg("element_type"), nb::arg("loc") = nb::none(),
760818
"Create a unranked tensor type");
819+
c.def_static(
820+
"get",
821+
[](PyType &elementType, DefaultingPyMlirContext context) {
822+
PyMlirContext::ErrorCapture errors(context->getRef());
823+
MlirType t = mlirUnrankedTensorTypeGet(elementType);
824+
if (mlirTypeIsNull(t))
825+
throw MLIRError("Invalid type", errors.take());
826+
return PyUnrankedTensorType(elementType.getContext(), t);
827+
},
828+
nb::arg("element_type"), nb::arg("context") = nb::none(),
829+
"Create a unranked tensor type");
761830
}
762831
};
763832

@@ -772,7 +841,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
772841

773842
static void bindDerived(ClassTy &c) {
774843
c.def_static(
775-
"get",
844+
"get_checked",
776845
[](std::vector<int64_t> shape, PyType &elementType,
777846
PyAttribute *layout, PyAttribute *memorySpace,
778847
DefaultingPyLocation loc) {
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
790859
nb::arg("shape"), nb::arg("element_type"),
791860
nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
792861
nb::arg("loc") = nb::none(), "Create a memref type")
862+
.def_static(
863+
"get",
864+
[](std::vector<int64_t> shape, PyType &elementType,
865+
PyAttribute *layout, PyAttribute *memorySpace,
866+
DefaultingPyMlirContext context) {
867+
PyMlirContext::ErrorCapture errors(context->getRef());
868+
MlirAttribute layoutAttr =
869+
layout ? *layout : mlirAttributeGetNull();
870+
MlirAttribute memSpaceAttr =
871+
memorySpace ? *memorySpace : mlirAttributeGetNull();
872+
MlirType t =
873+
mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
874+
layoutAttr, memSpaceAttr);
875+
if (mlirTypeIsNull(t))
876+
throw MLIRError("Invalid type", errors.take());
877+
return PyMemRefType(elementType.getContext(), t);
878+
},
879+
nb::arg("shape"), nb::arg("element_type"),
880+
nb::arg("layout") = nb::none(),
881+
nb::arg("memory_space") = nb::none(),
882+
nb::arg("context") = nb::none(), "Create a memref type")
793883
.def_prop_ro(
794884
"layout",
795885
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -842,7 +932,7 @@ class PyUnrankedMemRefType
842932

843933
static void bindDerived(ClassTy &c) {
844934
c.def_static(
845-
"get",
935+
"get_checked",
846936
[](PyType &elementType, PyAttribute *memorySpace,
847937
DefaultingPyLocation loc) {
848938
PyMlirContext::ErrorCapture errors(loc->getContext());
@@ -858,6 +948,32 @@ class PyUnrankedMemRefType
858948
},
859949
nb::arg("element_type"), nb::arg("memory_space").none(),
860950
nb::arg("loc") = nb::none(), "Create a unranked memref type")
951+
.def_prop_ro(
952+
"memory_space",
953+
[](PyUnrankedMemRefType &self)
954+
-> std::optional<nb::typed<nb::object, PyAttribute>> {
955+
MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
956+
if (mlirAttributeIsNull(a))
957+
return std::nullopt;
958+
return PyAttribute(self.getContext(), a).maybeDownCast();
959+
},
960+
"Returns the memory space of the given Unranked MemRef type.")
961+
.def_static(
962+
"get",
963+
[](PyType &elementType, PyAttribute *memorySpace,
964+
DefaultingPyMlirContext context) {
965+
PyMlirContext::ErrorCapture errors(context->getRef());
966+
MlirAttribute memSpaceAttr = {};
967+
if (memorySpace)
968+
memSpaceAttr = *memorySpace;
969+
970+
MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
971+
if (mlirTypeIsNull(t))
972+
throw MLIRError("Invalid type", errors.take());
973+
return PyUnrankedMemRefType(elementType.getContext(), t);
974+
},
975+
nb::arg("element_type"), nb::arg("memory_space").none(),
976+
nb::arg("context") = nb::none(), "Create a unranked memref type")
861977
.def_prop_ro(
862978
"memory_space",
863979
[](PyUnrankedMemRefType &self)

0 commit comments

Comments
 (0)