@@ -1056,13 +1056,17 @@ class PyDenseElementsAttribute
10561056 static_cast <uint8_t *>(view.buf ));
10571057
10581058 py::module numpy = py::module::import (" numpy" );
1059- py::object packbits_func = numpy.attr (" packbits" );
1060- py::object packed_booleans =
1061- packbits_func (unpackedArray, " bitorder" _a = " little" );
1062- py::buffer_info pythonBuffer = packed_booleans .cast <py::buffer>().request ();
1059+ py::object packbitsFunc = numpy.attr (" packbits" );
1060+ py::object packedBooleans =
1061+ packbitsFunc (unpackedArray, " bitorder" _a = " little" );
1062+ py::buffer_info pythonBuffer = packedBooleans .cast <py::buffer>().request ();
10631063
10641064 MlirType bitpackedType =
10651065 getShapedType (mlirIntegerTypeGet (context, 1 ), explicitShape, view);
1066+ assert (pythonBuffer.itemsize == 1 && " Packbits must return uint8" );
1067+ // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
1068+ // packedBooleans, hence the MlirAttribute will remain valid even when
1069+ // packedBooleans get reclaimed by the end of the function.
10661070 return mlirDenseElementsAttrRawBufferGet (bitpackedType, pythonBuffer.size ,
10671071 pythonBuffer.ptr );
10681072 }
@@ -1084,29 +1088,22 @@ class PyDenseElementsAttribute
10841088 py::array_t <uint8_t > packedArray (numBitpackedBytes, bitpackedData);
10851089
10861090 py::module numpy = py::module::import (" numpy" );
1087- py::object unpackbits_func = numpy.attr (" unpackbits" );
1088- py::object unpacked_booleans =
1089- unpackbits_func (packedArray, " bitorder" _a = " little" );
1090- py::buffer_info pythonBuffer =
1091- unpacked_booleans. cast <py::buffer>(). request ();
1092-
1093- MlirType shapedType = mlirAttributeGetType (* this );
1094- return bufferInfo< bool >(shapedType, ( bool *) pythonBuffer.ptr , " ? " );
1091+ py::object unpackbitsFunc = numpy.attr (" unpackbits" );
1092+ py::object unpackedBooleans =
1093+ unpackbitsFunc (packedArray, " bitorder" _a = " little" );
1094+ py::buffer pythonBuffer = unpackedBooleans. cast <py::buffer>();
1095+
1096+ // Make sure the returned py::buffer_view claims ownership of the data in
1097+ // `pythonBuffer` so it remains valid when Python reads it
1098+ return pythonBuffer.request ( );
10951099 }
10961100
10971101 template <typename Type>
10981102 py::buffer_info bufferInfo (MlirType shapedType,
10991103 const char *explicitFormat = nullptr ) {
1100- // Prepare the data for the buffer_info.
1101- // Buffer is configured for read-only access inside the `bufferInfo` call.
1104+ // Buffer is configured for read-only access below
11021105 Type *data = static_cast <Type *>(
11031106 const_cast <void *>(mlirDenseElementsAttrGetRawData (*this )));
1104- return bufferInfo<Type>(shapedType, data, explicitFormat);
1105- }
1106-
1107- template <typename Type>
1108- py::buffer_info bufferInfo (MlirType shapedType, Type *data,
1109- const char *explicitFormat = nullptr ) {
11101107 intptr_t rank = mlirShapedTypeGetRank (shapedType);
11111108 // Prepare the shape for the buffer_info.
11121109 SmallVector<intptr_t , 4 > shape;
0 commit comments