@@ -639,11 +639,16 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
639
639
using PyConcreteType::PyConcreteType;
640
640
641
641
static void bindDerived (ClassTy &c) {
642
- c.def_static (" get" , &PyVectorType::get , nb::arg (" shape" ),
642
+ c.def_static (" get" , &PyVectorType::getChecked , nb::arg (" shape" ),
643
643
nb::arg (" element_type" ), nb::kw_only (),
644
644
nb::arg (" scalable" ) = nb::none (),
645
645
nb::arg (" scalable_dims" ) = nb::none (),
646
646
nb::arg (" loc" ) = nb::none (), " Create a vector type" )
647
+ .def_static (" get_unchecked" , &PyVectorType::get, nb::arg (" shape" ),
648
+ nb::arg (" element_type" ), nb::kw_only (),
649
+ nb::arg (" scalable" ) = nb::none (),
650
+ nb::arg (" scalable_dims" ) = nb::none (),
651
+ nb::arg (" context" ) = nb::none (), " Create a vector type" )
647
652
.def_prop_ro (
648
653
" scalable" ,
649
654
[](MlirType self) { return mlirVectorTypeIsScalable (self); })
@@ -658,10 +663,11 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
658
663
}
659
664
660
665
private:
661
- static PyVectorType get (std::vector<int64_t > shape, PyType &elementType,
662
- std::optional<nb::list> scalable,
663
- std::optional<std::vector<int64_t >> scalableDims,
664
- DefaultingPyLocation loc) {
666
+ static PyVectorType
667
+ getChecked (std::vector<int64_t > shape, PyType &elementType,
668
+ std::optional<nb::list> scalable,
669
+ std::optional<std::vector<int64_t >> scalableDims,
670
+ DefaultingPyLocation loc) {
665
671
if (scalable && scalableDims) {
666
672
throw nb::value_error (" 'scalable' and 'scalable_dims' kwargs "
667
673
" are mutually exclusive." );
@@ -696,6 +702,42 @@ class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
696
702
throw MLIRError (" Invalid type" , errors.take ());
697
703
return PyVectorType (elementType.getContext (), type);
698
704
}
705
+
706
+ static PyVectorType get (std::vector<int64_t > shape, PyType &elementType,
707
+ std::optional<nb::list> scalable,
708
+ std::optional<std::vector<int64_t >> scalableDims,
709
+ DefaultingPyMlirContext context) {
710
+ if (scalable && scalableDims) {
711
+ throw nb::value_error (" 'scalable' and 'scalable_dims' kwargs "
712
+ " are mutually exclusive." );
713
+ }
714
+
715
+ PyMlirContext::ErrorCapture errors (context->getRef ());
716
+ MlirType type;
717
+ if (scalable) {
718
+ if (scalable->size () != shape.size ())
719
+ throw nb::value_error (" Expected len(scalable) == len(shape)." );
720
+
721
+ SmallVector<bool > scalableDimFlags = llvm::to_vector (llvm::map_range (
722
+ *scalable, [](const nb::handle &h) { return nb::cast<bool >(h); }));
723
+ type = mlirVectorTypeGetScalable (shape.size (), shape.data (),
724
+ scalableDimFlags.data (), elementType);
725
+ } else if (scalableDims) {
726
+ SmallVector<bool > scalableDimFlags (shape.size (), false );
727
+ for (int64_t dim : *scalableDims) {
728
+ if (static_cast <size_t >(dim) >= scalableDimFlags.size () || dim < 0 )
729
+ throw nb::value_error (" Scalable dimension index out of bounds." );
730
+ scalableDimFlags[dim] = true ;
731
+ }
732
+ type = mlirVectorTypeGetScalable (shape.size (), shape.data (),
733
+ scalableDimFlags.data (), elementType);
734
+ } else {
735
+ type = mlirVectorTypeGet (shape.size (), shape.data (), elementType);
736
+ }
737
+ if (mlirTypeIsNull (type))
738
+ throw MLIRError (" Invalid type" , errors.take ());
739
+ return PyVectorType (elementType.getContext (), type);
740
+ }
699
741
};
700
742
701
743
// / Ranked Tensor Type subclass - RankedTensorType.
@@ -724,6 +766,22 @@ class PyRankedTensorType
724
766
nb::arg (" shape" ), nb::arg (" element_type" ),
725
767
nb::arg (" encoding" ) = nb::none (), nb::arg (" loc" ) = nb::none (),
726
768
" Create a ranked tensor type" );
769
+ c.def_static (
770
+ " get_unchecked" ,
771
+ [](std::vector<int64_t > shape, PyType &elementType,
772
+ std::optional<PyAttribute> &encodingAttr,
773
+ DefaultingPyMlirContext context) {
774
+ PyMlirContext::ErrorCapture errors (context->getRef ());
775
+ MlirType t = mlirRankedTensorTypeGet (
776
+ shape.size (), shape.data (), elementType,
777
+ encodingAttr ? encodingAttr->get () : mlirAttributeGetNull ());
778
+ if (mlirTypeIsNull (t))
779
+ throw MLIRError (" Invalid type" , errors.take ());
780
+ return PyRankedTensorType (elementType.getContext (), t);
781
+ },
782
+ nb::arg (" shape" ), nb::arg (" element_type" ),
783
+ nb::arg (" encoding" ) = nb::none (), nb::arg (" context" ) = nb::none (),
784
+ " Create a ranked tensor type" );
727
785
c.def_prop_ro (
728
786
" encoding" ,
729
787
[](PyRankedTensorType &self)
@@ -758,6 +816,17 @@ class PyUnrankedTensorType
758
816
},
759
817
nb::arg (" element_type" ), nb::arg (" loc" ) = nb::none (),
760
818
" Create a unranked tensor type" );
819
+ c.def_static (
820
+ " get_unchecked" ,
821
+ [](PyType &elementType, DefaultingPyMlirContext context) {
822
+ PyMlirContext::ErrorCapture errors (context->getRef ());
823
+ MlirType t = mlirUnrankedTensorTypeGet (elementType);
824
+ if (mlirTypeIsNull (t))
825
+ throw MLIRError (" Invalid type" , errors.take ());
826
+ return PyUnrankedTensorType (elementType.getContext (), t);
827
+ },
828
+ nb::arg (" element_type" ), nb::arg (" context" ) = nb::none (),
829
+ " Create a unranked tensor type" );
761
830
}
762
831
};
763
832
@@ -790,6 +859,27 @@ class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
790
859
nb::arg (" shape" ), nb::arg (" element_type" ),
791
860
nb::arg (" layout" ) = nb::none (), nb::arg (" memory_space" ) = nb::none (),
792
861
nb::arg (" loc" ) = nb::none (), " Create a memref type" )
862
+ .def_static (
863
+ " get_unchecked" ,
864
+ [](std::vector<int64_t > shape, PyType &elementType,
865
+ PyAttribute *layout, PyAttribute *memorySpace,
866
+ DefaultingPyMlirContext context) {
867
+ PyMlirContext::ErrorCapture errors (context->getRef ());
868
+ MlirAttribute layoutAttr =
869
+ layout ? *layout : mlirAttributeGetNull ();
870
+ MlirAttribute memSpaceAttr =
871
+ memorySpace ? *memorySpace : mlirAttributeGetNull ();
872
+ MlirType t =
873
+ mlirMemRefTypeGet (elementType, shape.size (), shape.data (),
874
+ layoutAttr, memSpaceAttr);
875
+ if (mlirTypeIsNull (t))
876
+ throw MLIRError (" Invalid type" , errors.take ());
877
+ return PyMemRefType (elementType.getContext (), t);
878
+ },
879
+ nb::arg (" shape" ), nb::arg (" element_type" ),
880
+ nb::arg (" layout" ) = nb::none (),
881
+ nb::arg (" memory_space" ) = nb::none (),
882
+ nb::arg (" context" ) = nb::none (), " Create a memref type" )
793
883
.def_prop_ro (
794
884
" layout" ,
795
885
[](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
@@ -858,6 +948,22 @@ class PyUnrankedMemRefType
858
948
},
859
949
nb::arg (" element_type" ), nb::arg (" memory_space" ).none (),
860
950
nb::arg (" loc" ) = nb::none (), " Create a unranked memref type" )
951
+ .def_static (
952
+ " get_unchecked" ,
953
+ [](PyType &elementType, PyAttribute *memorySpace,
954
+ DefaultingPyMlirContext context) {
955
+ PyMlirContext::ErrorCapture errors (context->getRef ());
956
+ MlirAttribute memSpaceAttr = {};
957
+ if (memorySpace)
958
+ memSpaceAttr = *memorySpace;
959
+
960
+ MlirType t = mlirUnrankedMemRefTypeGet (elementType, memSpaceAttr);
961
+ if (mlirTypeIsNull (t))
962
+ throw MLIRError (" Invalid type" , errors.take ());
963
+ return PyUnrankedMemRefType (elementType.getContext (), t);
964
+ },
965
+ nb::arg (" element_type" ), nb::arg (" memory_space" ).none (),
966
+ nb::arg (" context" ) = nb::none (), " Create a unranked memref type" )
861
967
.def_prop_ro (
862
968
" memory_space" ,
863
969
[](PyUnrankedMemRefType &self)
0 commit comments