@@ -643,7 +643,12 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
643643 nb::arg (" element_type" ), nb::kw_only (),
644644 nb::arg (" scalable" ) = nb::none (),
645645 nb::arg (" scalable_dims" ) = nb::none (),
646- nb::arg (" loc" ) = nb::none (), " Create a vector type" )
646+ nb::arg (" context" ) = nb::none (), " Create a vector type" )
647+ .def_static (" get" , &PyVectorType::getChecked, nb::arg (" shape" ),
648+ nb::arg (" element_type" ), nb::kw_only (),
649+ nb::arg (" scalable" ) = nb::none (),
650+ nb::arg (" scalable_dims" ) = nb::none (),
651+ nb::arg (" loc" ) = nb::none (), " Create a vector type" )
647652 .def_prop_ro (
648653 " scalable" ,
649654 [](MlirType self) { return mlirVectorTypeIsScalable (self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
658663 }
659664
660665private:
661- static PyVectorType get (std::vector<int64_t > shape, PyType &elementType,
662- std::optional<nb::list> scalable,
663- std::optional<std::vector<int64_t >> scalableDims,
664- DefaultingPyLocation loc) {
666+ static PyVectorType
667+ getChecked (std::vector<int64_t > shape, PyType &elementType,
668+ std::optional<nb::list> scalable,
669+ std::optional<std::vector<int64_t >> scalableDims,
670+ DefaultingPyLocation loc) {
665671 if (scalable && scalableDims) {
666672 throw nb::value_error (" 'scalable' and 'scalable_dims' kwargs "
667673 " are mutually exclusive." );
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
696702 throw MLIRError (" Invalid type" , errors.take ());
697703 return PyVectorType (elementType.getContext (), type);
698704 }
705+
706+ static PyVectorType get (std::vector<int64_t > shape, PyType &elementType,
707+ std::optional<nb::list> scalable,
708+ std::optional<std::vector<int64_t >> scalableDims,
709+ DefaultingPyMlirContext context) {
710+ if (scalable && scalableDims) {
711+ throw nb::value_error (" 'scalable' and 'scalable_dims' kwargs "
712+ " are mutually exclusive." );
713+ }
714+
715+ PyMlirContext::ErrorCapture errors (context->getRef ());
716+ MlirType type;
717+ if (scalable) {
718+ if (scalable->size () != shape.size ())
719+ throw nb::value_error (" Expected len(scalable) == len(shape)." );
720+
721+ SmallVector<bool > scalableDimFlags = llvm::to_vector (llvm::map_range (
722+ *scalable, [](const nb::handle &h) { return nb::cast<bool >(h); }));
723+ type = mlirVectorTypeGetScalable (shape.size (), shape.data (),
724+ scalableDimFlags.data (), elementType);
725+ } else if (scalableDims) {
726+ SmallVector<bool > scalableDimFlags (shape.size (), false );
727+ for (int64_t dim : *scalableDims) {
728+ if (static_cast <size_t >(dim) >= scalableDimFlags.size () || dim < 0 )
729+ throw nb::value_error (" Scalable dimension index out of bounds." );
730+ scalableDimFlags[dim] = true ;
731+ }
732+ type = mlirVectorTypeGetScalable (shape.size (), shape.data (),
733+ scalableDimFlags.data (), elementType);
734+ } else {
735+ type = mlirVectorTypeGet (shape.size (), shape.data (), elementType);
736+ }
737+ if (mlirTypeIsNull (type))
738+ throw MLIRError (" Invalid type" , errors.take ());
739+ return PyVectorType (elementType.getContext (), type);
740+ }
699741};
700742
701743// / Ranked Tensor Type subclass - RankedTensorType.
@@ -710,7 +752,7 @@ class PyRankedTensorType
710752
711753 static void bindDerived (ClassTy &c) {
712754 c.def_static (
713- " get " ,
755+ " get_checked " ,
714756 [](std::vector<int64_t > shape, PyType &elementType,
715757 std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
716758 PyMlirContext::ErrorCapture errors (loc->getContext ());
@@ -724,6 +766,22 @@ class PyRankedTensorType
724766 nb::arg (" shape" ), nb::arg (" element_type" ),
725767 nb::arg (" encoding" ) = nb::none (), nb::arg (" loc" ) = nb::none (),
726768 " Create a ranked tensor type" );
769+ c.def_static (
770+ " get" ,
771+ [](std::vector<int64_t > shape, PyType &elementType,
772+ std::optional<PyAttribute> &encodingAttr,
773+ DefaultingPyMlirContext context) {
774+ PyMlirContext::ErrorCapture errors (context->getRef ());
775+ MlirType t = mlirRankedTensorTypeGet (
776+ shape.size (), shape.data (), elementType,
777+ encodingAttr ? encodingAttr->get () : mlirAttributeGetNull ());
778+ if (mlirTypeIsNull (t))
779+ throw MLIRError (" Invalid type" , errors.take ());
780+ return PyRankedTensorType (elementType.getContext (), t);
781+ },
782+ nb::arg (" shape" ), nb::arg (" element_type" ),
783+ nb::arg (" encoding" ) = nb::none (), nb::arg (" context" ) = nb::none (),
784+ " Create a ranked tensor type" );
727785 c.def_prop_ro (
728786 " encoding" ,
729787 [](PyRankedTensorType &self)
@@ -748,7 +806,7 @@ class PyUnrankedTensorType
748806
749807 static void bindDerived (ClassTy &c) {
750808 c.def_static (
751- " get " ,
809+ " get_checked " ,
752810 [](PyType &elementType, DefaultingPyLocation loc) {
753811 PyMlirContext::ErrorCapture errors (loc->getContext ());
754812 MlirType t = mlirUnrankedTensorTypeGetChecked (loc, elementType);
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
758816 },
759817 nb::arg (" element_type" ), nb::arg (" loc" ) = nb::none (),
760818 " Create a unranked tensor type" );
819+ c.def_static (
820+ " get" ,
821+ [](PyType &elementType, DefaultingPyMlirContext context) {
822+ PyMlirContext::ErrorCapture errors (context->getRef ());
823+ MlirType t = mlirUnrankedTensorTypeGet (elementType);
824+ if (mlirTypeIsNull (t))
825+ throw MLIRError (" Invalid type" , errors.take ());
826+ return PyUnrankedTensorType (elementType.getContext (), t);
827+ },
828+ nb::arg (" element_type" ), nb::arg (" context" ) = nb::none (),
829+ " Create a unranked tensor type" );
761830 }
762831};
763832
@@ -772,7 +841,7 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
772841
773842 static void bindDerived (ClassTy &c) {
774843 c.def_static (
775- " get " ,
844+ " get_checked " ,
776845 [](std::vector<int64_t > shape, PyType &elementType,
777846 PyAttribute *layout, PyAttribute *memorySpace,
778847 DefaultingPyLocation loc) {
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
790859 nb::arg (" shape" ), nb::arg (" element_type" ),
791860 nb::arg (" layout" ) = nb::none (), nb::arg (" memory_space" ) = nb::none (),
792861 nb::arg (" loc" ) = nb::none (), " Create a memref type" )
862+ .def_static (
863+ " get" ,
864+ [](std::vector<int64_t > shape, PyType &elementType,
865+ PyAttribute *layout, PyAttribute *memorySpace,
866+ DefaultingPyMlirContext context) {
867+ PyMlirContext::ErrorCapture errors (context->getRef ());
868+ MlirAttribute layoutAttr =
869+ layout ? *layout : mlirAttributeGetNull ();
870+ MlirAttribute memSpaceAttr =
871+ memorySpace ? *memorySpace : mlirAttributeGetNull ();
872+ MlirType t =
873+ mlirMemRefTypeGet (elementType, shape.size (), shape.data (),
874+ layoutAttr, memSpaceAttr);
875+ if (mlirTypeIsNull (t))
876+ throw MLIRError (" Invalid type" , errors.take ());
877+ return PyMemRefType (elementType.getContext (), t);
878+ },
879+ nb::arg (" shape" ), nb::arg (" element_type" ),
880+ nb::arg (" layout" ) = nb::none (),
881+ nb::arg (" memory_space" ) = nb::none (),
882+ nb::arg (" context" ) = nb::none (), " Create a memref type" )
793883 .def_prop_ro (
794884 " layout" ,
795885 [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -842,7 +932,7 @@ class PyUnrankedMemRefType
842932
843933 static void bindDerived (ClassTy &c) {
844934 c.def_static (
845- " get " ,
935+ " get_checked " ,
846936 [](PyType &elementType, PyAttribute *memorySpace,
847937 DefaultingPyLocation loc) {
848938 PyMlirContext::ErrorCapture errors (loc->getContext ());
@@ -858,6 +948,32 @@ class PyUnrankedMemRefType
858948 },
859949 nb::arg (" element_type" ), nb::arg (" memory_space" ).none (),
860950 nb::arg (" loc" ) = nb::none (), " Create a unranked memref type" )
951+ .def_prop_ro (
952+ " memory_space" ,
953+ [](PyUnrankedMemRefType &self)
954+ -> std::optional<nb::typed<nb::object, PyAttribute>> {
955+ MlirAttribute a = mlirUnrankedMemrefGetMemorySpace (self);
956+ if (mlirAttributeIsNull (a))
957+ return std::nullopt ;
958+ return PyAttribute (self.getContext (), a).maybeDownCast ();
959+ },
960+ " Returns the memory space of the given Unranked MemRef type." )
961+ .def_static (
962+ " get" ,
963+ [](PyType &elementType, PyAttribute *memorySpace,
964+ DefaultingPyMlirContext context) {
965+ PyMlirContext::ErrorCapture errors (context->getRef ());
966+ MlirAttribute memSpaceAttr = {};
967+ if (memorySpace)
968+ memSpaceAttr = *memorySpace;
969+
970+ MlirType t = mlirUnrankedMemRefTypeGet (elementType, memSpaceAttr);
971+ if (mlirTypeIsNull (t))
972+ throw MLIRError (" Invalid type" , errors.take ());
973+ return PyUnrankedMemRefType (elementType.getContext (), t);
974+ },
975+ nb::arg (" element_type" ), nb::arg (" memory_space" ).none (),
976+ nb::arg (" context" ) = nb::none (), " Create a unranked memref type" )
861977 .def_prop_ro (
862978 " memory_space" ,
863979 [](PyUnrankedMemRefType &self)
0 commit comments