Skip to content

Commit 34d58d2

Browse files
maksleventalmahesh-attarde
authored andcommitted
[MLIR][Python] add unchecked gettors (llvm#160954)
Some of the current gettors require passing locations (i.e., there be an active location) because they're using the "checked" APIs. This PR adds "unchecked" gettors which only require an active context.
1 parent 18d3a58 commit 34d58d2

File tree

4 files changed

+162
-23
lines changed

4 files changed

+162
-23
lines changed

mlir/lib/Bindings/Python/DialectLLVM.cpp

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) {
3333
auto llvmStructType =
3434
mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType);
3535

36-
llvmStructType.def_classmethod(
37-
"get_literal",
38-
[](const nb::object &cls, const std::vector<MlirType> &elements,
39-
bool packed, MlirLocation loc) {
40-
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
41-
42-
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
43-
loc, elements.size(), elements.data(), packed);
44-
if (mlirTypeIsNull(type)) {
45-
throw nb::value_error(scope.takeMessage().c_str());
46-
}
47-
return cls(type);
48-
},
49-
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
50-
"loc"_a = nb::none());
36+
llvmStructType
37+
.def_classmethod(
38+
"get_literal",
39+
[](const nb::object &cls, const std::vector<MlirType> &elements,
40+
bool packed, MlirLocation loc) {
41+
CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc));
42+
43+
MlirType type = mlirLLVMStructTypeLiteralGetChecked(
44+
loc, elements.size(), elements.data(), packed);
45+
if (mlirTypeIsNull(type)) {
46+
throw nb::value_error(scope.takeMessage().c_str());
47+
}
48+
return cls(type);
49+
},
50+
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
51+
"loc"_a = nb::none())
52+
.def_classmethod(
53+
"get_literal_unchecked",
54+
[](const nb::object &cls, const std::vector<MlirType> &elements,
55+
bool packed, MlirContext context) {
56+
CollectDiagnosticsToStringScope scope(context);
57+
58+
MlirType type = mlirLLVMStructTypeLiteralGet(
59+
context, elements.size(), elements.data(), packed);
60+
if (mlirTypeIsNull(type)) {
61+
throw nb::value_error(scope.takeMessage().c_str());
62+
}
63+
return cls(type);
64+
},
65+
"cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false,
66+
"context"_a = nb::none());
5167

5268
llvmStructType.def_classmethod(
5369
"get_identified",

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,18 @@ class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
575575
},
576576
nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(),
577577
"Gets an uniqued float point attribute associated to a type");
578+
c.def_static(
579+
"get_unchecked",
580+
[](PyType &type, double value, DefaultingPyMlirContext context) {
581+
PyMlirContext::ErrorCapture errors(context->getRef());
582+
MlirAttribute attr =
583+
mlirFloatAttrDoubleGet(context.get()->get(), type, value);
584+
if (mlirAttributeIsNull(attr))
585+
throw MLIRError("Invalid attribute", errors.take());
586+
return PyFloatAttribute(type.getContext(), attr);
587+
},
588+
nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(),
589+
"Gets an uniqued float point attribute associated to a type");
578590
c.def_static(
579591
"get_f32",
580592
[](double value, DefaultingPyMlirContext context) {

mlir/lib/Bindings/Python/IRTypes.cpp

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -639,11 +639,16 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
639639
using PyConcreteType::PyConcreteType;
640640

641641
static void bindDerived(ClassTy &c) {
642-
c.def_static("get", &PyVectorType::get, nb::arg("shape"),
642+
c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
643643
nb::arg("element_type"), nb::kw_only(),
644644
nb::arg("scalable") = nb::none(),
645645
nb::arg("scalable_dims") = nb::none(),
646646
nb::arg("loc") = nb::none(), "Create a vector type")
647+
.def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
648+
nb::arg("element_type"), nb::kw_only(),
649+
nb::arg("scalable") = nb::none(),
650+
nb::arg("scalable_dims") = nb::none(),
651+
nb::arg("context") = nb::none(), "Create a vector type")
647652
.def_prop_ro(
648653
"scalable",
649654
[](MlirType self) { return mlirVectorTypeIsScalable(self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
658663
}
659664

660665
private:
661-
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
662-
std::optional<nb::list> scalable,
663-
std::optional<std::vector<int64_t>> scalableDims,
664-
DefaultingPyLocation loc) {
666+
static PyVectorType
667+
getChecked(std::vector<int64_t> shape, PyType &elementType,
668+
std::optional<nb::list> scalable,
669+
std::optional<std::vector<int64_t>> scalableDims,
670+
DefaultingPyLocation loc) {
665671
if (scalable && scalableDims) {
666672
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
667673
"are mutually exclusive.");
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
696702
throw MLIRError("Invalid type", errors.take());
697703
return PyVectorType(elementType.getContext(), type);
698704
}
705+
706+
static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
707+
std::optional<nb::list> scalable,
708+
std::optional<std::vector<int64_t>> scalableDims,
709+
DefaultingPyMlirContext context) {
710+
if (scalable && scalableDims) {
711+
throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
712+
"are mutually exclusive.");
713+
}
714+
715+
PyMlirContext::ErrorCapture errors(context->getRef());
716+
MlirType type;
717+
if (scalable) {
718+
if (scalable->size() != shape.size())
719+
throw nb::value_error("Expected len(scalable) == len(shape).");
720+
721+
SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
722+
*scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
723+
type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
724+
scalableDimFlags.data(), elementType);
725+
} else if (scalableDims) {
726+
SmallVector<bool> scalableDimFlags(shape.size(), false);
727+
for (int64_t dim : *scalableDims) {
728+
if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
729+
throw nb::value_error("Scalable dimension index out of bounds.");
730+
scalableDimFlags[dim] = true;
731+
}
732+
type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
733+
scalableDimFlags.data(), elementType);
734+
} else {
735+
type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
736+
}
737+
if (mlirTypeIsNull(type))
738+
throw MLIRError("Invalid type", errors.take());
739+
return PyVectorType(elementType.getContext(), type);
740+
}
699741
};
700742

701743
/// Ranked Tensor Type subclass - RankedTensorType.
@@ -724,6 +766,22 @@ class PyRankedTensorType
724766
nb::arg("shape"), nb::arg("element_type"),
725767
nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
726768
"Create a ranked tensor type");
769+
c.def_static(
770+
"get_unchecked",
771+
[](std::vector<int64_t> shape, PyType &elementType,
772+
std::optional<PyAttribute> &encodingAttr,
773+
DefaultingPyMlirContext context) {
774+
PyMlirContext::ErrorCapture errors(context->getRef());
775+
MlirType t = mlirRankedTensorTypeGet(
776+
shape.size(), shape.data(), elementType,
777+
encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
778+
if (mlirTypeIsNull(t))
779+
throw MLIRError("Invalid type", errors.take());
780+
return PyRankedTensorType(elementType.getContext(), t);
781+
},
782+
nb::arg("shape"), nb::arg("element_type"),
783+
nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
784+
"Create a ranked tensor type");
727785
c.def_prop_ro(
728786
"encoding",
729787
[](PyRankedTensorType &self)
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
758816
},
759817
nb::arg("element_type"), nb::arg("loc") = nb::none(),
760818
"Create a unranked tensor type");
819+
c.def_static(
820+
"get_unchecked",
821+
[](PyType &elementType, DefaultingPyMlirContext context) {
822+
PyMlirContext::ErrorCapture errors(context->getRef());
823+
MlirType t = mlirUnrankedTensorTypeGet(elementType);
824+
if (mlirTypeIsNull(t))
825+
throw MLIRError("Invalid type", errors.take());
826+
return PyUnrankedTensorType(elementType.getContext(), t);
827+
},
828+
nb::arg("element_type"), nb::arg("context") = nb::none(),
829+
"Create a unranked tensor type");
761830
}
762831
};
763832

@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
790859
nb::arg("shape"), nb::arg("element_type"),
791860
nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
792861
nb::arg("loc") = nb::none(), "Create a memref type")
862+
.def_static(
863+
"get_unchecked",
864+
[](std::vector<int64_t> shape, PyType &elementType,
865+
PyAttribute *layout, PyAttribute *memorySpace,
866+
DefaultingPyMlirContext context) {
867+
PyMlirContext::ErrorCapture errors(context->getRef());
868+
MlirAttribute layoutAttr =
869+
layout ? *layout : mlirAttributeGetNull();
870+
MlirAttribute memSpaceAttr =
871+
memorySpace ? *memorySpace : mlirAttributeGetNull();
872+
MlirType t =
873+
mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
874+
layoutAttr, memSpaceAttr);
875+
if (mlirTypeIsNull(t))
876+
throw MLIRError("Invalid type", errors.take());
877+
return PyMemRefType(elementType.getContext(), t);
878+
},
879+
nb::arg("shape"), nb::arg("element_type"),
880+
nb::arg("layout") = nb::none(),
881+
nb::arg("memory_space") = nb::none(),
882+
nb::arg("context") = nb::none(), "Create a memref type")
793883
.def_prop_ro(
794884
"layout",
795885
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -858,6 +948,22 @@ class PyUnrankedMemRefType
858948
},
859949
nb::arg("element_type"), nb::arg("memory_space").none(),
860950
nb::arg("loc") = nb::none(), "Create a unranked memref type")
951+
.def_static(
952+
"get_unchecked",
953+
[](PyType &elementType, PyAttribute *memorySpace,
954+
DefaultingPyMlirContext context) {
955+
PyMlirContext::ErrorCapture errors(context->getRef());
956+
MlirAttribute memSpaceAttr = {};
957+
if (memorySpace)
958+
memSpaceAttr = *memorySpace;
959+
960+
MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
961+
if (mlirTypeIsNull(t))
962+
throw MLIRError("Invalid type", errors.take());
963+
return PyUnrankedMemRefType(elementType.getContext(), t);
964+
},
965+
nb::arg("element_type"), nb::arg("memory_space").none(),
966+
nb::arg("context") = nb::none(), "Create a unranked memref type")
861967
.def_prop_ro(
862968
"memory_space",
863969
[](PyUnrankedMemRefType &self)

mlir/test/python/ir/builtin_types.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,16 @@ def testAbstractShapedType():
371371
# CHECK-LABEL: TEST: testVectorType
372372
@run
373373
def testVectorType():
374+
shape = [2, 3]
375+
with Context():
376+
f32 = F32Type.get()
377+
# CHECK: unchecked vector type: vector<2x3xf32>
378+
print("unchecked vector type:", VectorType.get_unchecked(shape, f32))
379+
374380
with Context(), Location.unknown():
375381
f32 = F32Type.get()
376-
shape = [2, 3]
377-
# CHECK: vector type: vector<2x3xf32>
378-
print("vector type:", VectorType.get(shape, f32))
382+
# CHECK: checked vector type: vector<2x3xf32>
383+
print("checked vector type:", VectorType.get(shape, f32))
379384

380385
none = NoneType.get()
381386
try:

0 commit comments

Comments
 (0)