|
13 | 13 | #include "IRModule.h" |
14 | 14 |
|
15 | 15 | #include "PybindUtils.h" |
16 | | -#include <pybind11/numpy.h> |
17 | 16 |
|
18 | 17 | #include "llvm/ADT/ScopeExit.h" |
19 | 18 | #include "llvm/Support/raw_ostream.h" |
@@ -758,10 +757,103 @@ class PyDenseElementsAttribute |
758 | 757 | throw py::error_already_set(); |
759 | 758 | } |
760 | 759 | auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); |
| 760 | + SmallVector<int64_t> shape; |
| 761 | + if (explicitShape) { |
| 762 | + shape.append(explicitShape->begin(), explicitShape->end()); |
| 763 | + } else { |
| 764 | + shape.append(view.shape, view.shape + view.ndim); |
| 765 | + } |
761 | 766 |
|
| 767 | + MlirAttribute encodingAttr = mlirAttributeGetNull(); |
762 | 768 | MlirContext context = contextWrapper->get(); |
763 | | - MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, |
764 | | - explicitShape, context); |
| 769 | + |
| 770 | + // Detect format codes that are suitable for bulk loading. This includes |
| 771 | + // all byte aligned integer and floating point types up to 8 bytes. |
| 772 | + // Notably, this excludes bool (which needs to be bit-packed) and |
| 773 | + // other exotics which do not have a direct representation in the buffer |
| 774 | + // protocol (i.e. complex, etc). |
| 775 | + std::optional<MlirType> bulkLoadElementType; |
| 776 | + if (explicitType) { |
| 777 | + bulkLoadElementType = *explicitType; |
| 778 | + } else { |
| 779 | + std::string_view format(view.format); |
| 780 | + if (format == "f") { |
| 781 | + // f32 |
| 782 | + assert(view.itemsize == 4 && "mismatched array itemsize"); |
| 783 | + bulkLoadElementType = mlirF32TypeGet(context); |
| 784 | + } else if (format == "d") { |
| 785 | + // f64 |
| 786 | + assert(view.itemsize == 8 && "mismatched array itemsize"); |
| 787 | + bulkLoadElementType = mlirF64TypeGet(context); |
| 788 | + } else if (format == "e") { |
| 789 | + // f16 |
| 790 | + assert(view.itemsize == 2 && "mismatched array itemsize"); |
| 791 | + bulkLoadElementType = mlirF16TypeGet(context); |
| 792 | + } else if (isSignedIntegerFormat(format)) { |
| 793 | + if (view.itemsize == 4) { |
| 794 | + // i32 |
| 795 | + bulkLoadElementType = signless |
| 796 | + ? mlirIntegerTypeGet(context, 32) |
| 797 | + : mlirIntegerTypeSignedGet(context, 32); |
| 798 | + } else if (view.itemsize == 8) { |
| 799 | + // i64 |
| 800 | + bulkLoadElementType = signless |
| 801 | + ? mlirIntegerTypeGet(context, 64) |
| 802 | + : mlirIntegerTypeSignedGet(context, 64); |
| 803 | + } else if (view.itemsize == 1) { |
| 804 | + // i8 |
| 805 | + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) |
| 806 | + : mlirIntegerTypeSignedGet(context, 8); |
| 807 | + } else if (view.itemsize == 2) { |
| 808 | + // i16 |
| 809 | + bulkLoadElementType = signless |
| 810 | + ? mlirIntegerTypeGet(context, 16) |
| 811 | + : mlirIntegerTypeSignedGet(context, 16); |
| 812 | + } |
| 813 | + } else if (isUnsignedIntegerFormat(format)) { |
| 814 | + if (view.itemsize == 4) { |
| 815 | + // unsigned i32 |
| 816 | + bulkLoadElementType = signless |
| 817 | + ? mlirIntegerTypeGet(context, 32) |
| 818 | + : mlirIntegerTypeUnsignedGet(context, 32); |
| 819 | + } else if (view.itemsize == 8) { |
| 820 | + // unsigned i64 |
| 821 | + bulkLoadElementType = signless |
| 822 | + ? mlirIntegerTypeGet(context, 64) |
| 823 | + : mlirIntegerTypeUnsignedGet(context, 64); |
| 824 | + } else if (view.itemsize == 1) { |
| 825 | + // unsigned i8 |
| 826 | + bulkLoadElementType = signless |
| 827 | + ? mlirIntegerTypeGet(context, 8) |
| 828 | + : mlirIntegerTypeUnsignedGet(context, 8); |
| 829 | + } else if (view.itemsize == 2) { |
| 830 | + // unsigned i16 |
| 831 | + bulkLoadElementType = signless |
| 832 | + ? mlirIntegerTypeGet(context, 16) |
| 833 | + : mlirIntegerTypeUnsignedGet(context, 16); |
| 834 | + } |
| 835 | + } |
| 836 | + if (!bulkLoadElementType) { |
| 837 | + throw std::invalid_argument( |
| 838 | + std::string("unimplemented array format conversion from format: ") + |
| 839 | + std::string(format)); |
| 840 | + } |
| 841 | + } |
| 842 | + |
| 843 | + MlirType shapedType; |
| 844 | + if (mlirTypeIsAShaped(*bulkLoadElementType)) { |
| 845 | + if (explicitShape) { |
| 846 | + throw std::invalid_argument("Shape can only be specified explicitly " |
| 847 | + "when the type is not a shaped type."); |
| 848 | + } |
| 849 | + shapedType = *bulkLoadElementType; |
| 850 | + } else { |
| 851 | + shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), |
| 852 | + *bulkLoadElementType, encodingAttr); |
| 853 | + } |
| 854 | + size_t rawBufferSize = view.len; |
| 855 | + MlirAttribute attr = |
| 856 | + mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); |
765 | 857 | if (mlirAttributeIsNull(attr)) { |
766 | 858 | throw std::invalid_argument( |
767 | 859 | "DenseElementsAttr could not be constructed from the given buffer. " |
@@ -871,13 +963,6 @@ class PyDenseElementsAttribute |
871 | 963 | // unsigned i16 |
872 | 964 | return bufferInfo<uint16_t>(shapedType); |
873 | 965 | } |
874 | | - } else if (mlirTypeIsAInteger(elementType) && |
875 | | - mlirIntegerTypeGetWidth(elementType) == 1) { |
876 | | - // i1 / bool |
877 | | - // We can not send the buffer directly back to Python, because the i1 |
878 | | - // values are bitpacked within MLIR. We call numpy's unpackbits function |
879 | | - // to convert the bytes. |
880 | | - return getBooleanBufferFromBitpackedAttribute(); |
881 | 966 | } |
882 | 967 |
|
883 | 968 | // TODO: Currently crashes the program. |
@@ -931,183 +1016,14 @@ class PyDenseElementsAttribute |
931 | 1016 | code == 'q'; |
932 | 1017 | } |
933 | 1018 |
|
934 | | - static MlirType |
935 | | - getShapedType(std::optional<MlirType> bulkLoadElementType, |
936 | | - std::optional<std::vector<int64_t>> explicitShape, |
937 | | - Py_buffer &view) { |
938 | | - SmallVector<int64_t> shape; |
939 | | - if (explicitShape) { |
940 | | - shape.append(explicitShape->begin(), explicitShape->end()); |
941 | | - } else { |
942 | | - shape.append(view.shape, view.shape + view.ndim); |
943 | | - } |
944 | | - |
945 | | - if (mlirTypeIsAShaped(*bulkLoadElementType)) { |
946 | | - if (explicitShape) { |
947 | | - throw std::invalid_argument("Shape can only be specified explicitly " |
948 | | - "when the type is not a shaped type."); |
949 | | - } |
950 | | - return *bulkLoadElementType; |
951 | | - } else { |
952 | | - MlirAttribute encodingAttr = mlirAttributeGetNull(); |
953 | | - return mlirRankedTensorTypeGet(shape.size(), shape.data(), |
954 | | - *bulkLoadElementType, encodingAttr); |
955 | | - } |
956 | | - } |
957 | | - |
958 | | - static MlirAttribute getAttributeFromBuffer( |
959 | | - Py_buffer &view, bool signless, std::optional<PyType> explicitType, |
960 | | - std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) { |
961 | | - // Detect format codes that are suitable for bulk loading. This includes |
962 | | - // all byte aligned integer and floating point types up to 8 bytes. |
963 | | - // Notably, this excludes exotic types which do not have a direct |
964 | | - // representation in the buffer protocol (i.e. complex, etc). |
965 | | - std::optional<MlirType> bulkLoadElementType; |
966 | | - if (explicitType) { |
967 | | - bulkLoadElementType = *explicitType; |
968 | | - } else { |
969 | | - std::string_view format(view.format); |
970 | | - if (format == "f") { |
971 | | - // f32 |
972 | | - assert(view.itemsize == 4 && "mismatched array itemsize"); |
973 | | - bulkLoadElementType = mlirF32TypeGet(context); |
974 | | - } else if (format == "d") { |
975 | | - // f64 |
976 | | - assert(view.itemsize == 8 && "mismatched array itemsize"); |
977 | | - bulkLoadElementType = mlirF64TypeGet(context); |
978 | | - } else if (format == "e") { |
979 | | - // f16 |
980 | | - assert(view.itemsize == 2 && "mismatched array itemsize"); |
981 | | - bulkLoadElementType = mlirF16TypeGet(context); |
982 | | - } else if (format == "?") { |
983 | | - // i1 |
984 | | - // The i1 type needs to be bit-packed, so we will handle it separately |
985 | | - return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, |
986 | | - context); |
987 | | - } else if (isSignedIntegerFormat(format)) { |
988 | | - if (view.itemsize == 4) { |
989 | | - // i32 |
990 | | - bulkLoadElementType = signless |
991 | | - ? mlirIntegerTypeGet(context, 32) |
992 | | - : mlirIntegerTypeSignedGet(context, 32); |
993 | | - } else if (view.itemsize == 8) { |
994 | | - // i64 |
995 | | - bulkLoadElementType = signless |
996 | | - ? mlirIntegerTypeGet(context, 64) |
997 | | - : mlirIntegerTypeSignedGet(context, 64); |
998 | | - } else if (view.itemsize == 1) { |
999 | | - // i8 |
1000 | | - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) |
1001 | | - : mlirIntegerTypeSignedGet(context, 8); |
1002 | | - } else if (view.itemsize == 2) { |
1003 | | - // i16 |
1004 | | - bulkLoadElementType = signless |
1005 | | - ? mlirIntegerTypeGet(context, 16) |
1006 | | - : mlirIntegerTypeSignedGet(context, 16); |
1007 | | - } |
1008 | | - } else if (isUnsignedIntegerFormat(format)) { |
1009 | | - if (view.itemsize == 4) { |
1010 | | - // unsigned i32 |
1011 | | - bulkLoadElementType = signless |
1012 | | - ? mlirIntegerTypeGet(context, 32) |
1013 | | - : mlirIntegerTypeUnsignedGet(context, 32); |
1014 | | - } else if (view.itemsize == 8) { |
1015 | | - // unsigned i64 |
1016 | | - bulkLoadElementType = signless |
1017 | | - ? mlirIntegerTypeGet(context, 64) |
1018 | | - : mlirIntegerTypeUnsignedGet(context, 64); |
1019 | | - } else if (view.itemsize == 1) { |
1020 | | - // i8 |
1021 | | - bulkLoadElementType = signless |
1022 | | - ? mlirIntegerTypeGet(context, 8) |
1023 | | - : mlirIntegerTypeUnsignedGet(context, 8); |
1024 | | - } else if (view.itemsize == 2) { |
1025 | | - // i16 |
1026 | | - bulkLoadElementType = signless |
1027 | | - ? mlirIntegerTypeGet(context, 16) |
1028 | | - : mlirIntegerTypeUnsignedGet(context, 16); |
1029 | | - } |
1030 | | - } |
1031 | | - if (!bulkLoadElementType) { |
1032 | | - throw std::invalid_argument( |
1033 | | - std::string("unimplemented array format conversion from format: ") + |
1034 | | - std::string(format)); |
1035 | | - } |
1036 | | - } |
1037 | | - |
1038 | | - MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); |
1039 | | - return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); |
1040 | | - } |
1041 | | - |
1042 | | - // There is a complication for boolean numpy arrays, as numpy represents them |
1043 | | - // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans |
1044 | | - // per byte. |
1045 | | - static MlirAttribute getBitpackedAttributeFromBooleanBuffer( |
1046 | | - Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape, |
1047 | | - MlirContext &context) { |
1048 | | - if (llvm::endianness::native != llvm::endianness::little) { |
1049 | | - // Given we have no good way of testing the behavior on big-endian systems |
1050 | | - // we will throw |
1051 | | - throw py::type_error("Constructing a bit-packed MLIR attribute is " |
1052 | | - "unsupported on big-endian systems"); |
1053 | | - } |
1054 | | - |
1055 | | - py::array_t<uint8_t> unpackedArray(view.len, |
1056 | | - static_cast<uint8_t *>(view.buf)); |
1057 | | - |
1058 | | - py::module numpy = py::module::import("numpy"); |
1059 | | - py::object packbits_func = numpy.attr("packbits"); |
1060 | | - py::object packed_booleans = |
1061 | | - packbits_func(unpackedArray, "bitorder"_a = "little"); |
1062 | | - py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request(); |
1063 | | - |
1064 | | - MlirType bitpackedType = |
1065 | | - getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); |
1066 | | - return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, |
1067 | | - pythonBuffer.ptr); |
1068 | | - } |
1069 | | - |
1070 | | - // This does the opposite transformation of |
1071 | | - // `getBitpackedAttributeFromBooleanBuffer` |
1072 | | - py::buffer_info getBooleanBufferFromBitpackedAttribute() { |
1073 | | - if (llvm::endianness::native != llvm::endianness::little) { |
1074 | | - // Given we have no good way of testing the behavior on big-endian systems |
1075 | | - // we will throw |
1076 | | - throw py::type_error("Constructing a numpy array from a MLIR attribute " |
1077 | | - "is unsupported on big-endian systems"); |
1078 | | - } |
1079 | | - |
1080 | | - int64_t numBooleans = mlirElementsAttrGetNumElements(*this); |
1081 | | - int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); |
1082 | | - uint8_t *bitpackedData = static_cast<uint8_t *>( |
1083 | | - const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); |
1084 | | - py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData); |
1085 | | - |
1086 | | - py::module numpy = py::module::import("numpy"); |
1087 | | - py::object unpackbits_func = numpy.attr("unpackbits"); |
1088 | | - py::object unpacked_booleans = |
1089 | | - unpackbits_func(packedArray, "bitorder"_a = "little"); |
1090 | | - py::buffer_info pythonBuffer = |
1091 | | - unpacked_booleans.cast<py::buffer>().request(); |
1092 | | - |
1093 | | - MlirType shapedType = mlirAttributeGetType(*this); |
1094 | | - return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?"); |
1095 | | - } |
1096 | | - |
1097 | 1019 | template <typename Type> |
1098 | 1020 | py::buffer_info bufferInfo(MlirType shapedType, |
1099 | 1021 | const char *explicitFormat = nullptr) { |
| 1022 | + intptr_t rank = mlirShapedTypeGetRank(shapedType); |
1100 | 1023 | // Prepare the data for the buffer_info. |
1101 | | - // Buffer is configured for read-only access inside the `bufferInfo` call. |
| 1024 | + // Buffer is configured for read-only access below. |
1102 | 1025 | Type *data = static_cast<Type *>( |
1103 | 1026 | const_cast<void *>(mlirDenseElementsAttrGetRawData(*this))); |
1104 | | - return bufferInfo<Type>(shapedType, data, explicitFormat); |
1105 | | - } |
1106 | | - |
1107 | | - template <typename Type> |
1108 | | - py::buffer_info bufferInfo(MlirType shapedType, Type *data, |
1109 | | - const char *explicitFormat = nullptr) { |
1110 | | - intptr_t rank = mlirShapedTypeGetRank(shapedType); |
1111 | 1027 | // Prepare the shape for the buffer_info. |
1112 | 1028 | SmallVector<intptr_t, 4> shape; |
1113 | 1029 | for (intptr_t i = 0; i < rank; ++i) |
|
0 commit comments