Skip to content

Commit 75d34c9

Browse files
committed
Fix python buffer lifetime issues
1 parent b587e28 commit 75d34c9

File tree

1 file changed

+17
-20
lines changed

1 file changed

+17
-20
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,13 +1056,17 @@ class PyDenseElementsAttribute
10561056
static_cast<uint8_t *>(view.buf));
10571057

10581058
py::module numpy = py::module::import("numpy");
1059-
py::object packbits_func = numpy.attr("packbits");
1060-
py::object packed_booleans =
1061-
packbits_func(unpackedArray, "bitorder"_a = "little");
1062-
py::buffer_info pythonBuffer = packed_booleans.cast<py::buffer>().request();
1059+
py::object packbitsFunc = numpy.attr("packbits");
1060+
py::object packedBooleans =
1061+
packbitsFunc(unpackedArray, "bitorder"_a = "little");
1062+
py::buffer_info pythonBuffer = packedBooleans.cast<py::buffer>().request();
10631063

10641064
MlirType bitpackedType =
10651065
getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view);
1066+
assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8");
1067+
// Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of
1068+
// packedBooleans, hence the MlirAttribute will remain valid even when
1069+
// packedBooleans get reclaimed by the end of the function.
10661070
return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size,
10671071
pythonBuffer.ptr);
10681072
}
@@ -1084,29 +1088,22 @@ class PyDenseElementsAttribute
10841088
py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);
10851089

10861090
py::module numpy = py::module::import("numpy");
1087-
py::object unpackbits_func = numpy.attr("unpackbits");
1088-
py::object unpacked_booleans =
1089-
unpackbits_func(packedArray, "bitorder"_a = "little");
1090-
py::buffer_info pythonBuffer =
1091-
unpacked_booleans.cast<py::buffer>().request();
1092-
1093-
MlirType shapedType = mlirAttributeGetType(*this);
1094-
return bufferInfo<bool>(shapedType, (bool *)pythonBuffer.ptr, "?");
1091+
py::object unpackbitsFunc = numpy.attr("unpackbits");
1092+
py::object unpackedBooleans =
1093+
unpackbitsFunc(packedArray, "bitorder"_a = "little");
1094+
py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
1095+
1096+
// Make sure the returned py::buffer_view claims ownership of the data in
1097+
// `pythonBuffer` so it remains valid when Python reads it
1098+
return pythonBuffer.request();
10951099
}
10961100

10971101
template <typename Type>
10981102
py::buffer_info bufferInfo(MlirType shapedType,
10991103
const char *explicitFormat = nullptr) {
1100-
// Prepare the data for the buffer_info.
1101-
// Buffer is configured for read-only access inside the `bufferInfo` call.
1104+
// Buffer is configured for read-only access below
11021105
Type *data = static_cast<Type *>(
11031106
const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1104-
return bufferInfo<Type>(shapedType, data, explicitFormat);
1105-
}
1106-
1107-
template <typename Type>
1108-
py::buffer_info bufferInfo(MlirType shapedType, Type *data,
1109-
const char *explicitFormat = nullptr) {
11101107
intptr_t rank = mlirShapedTypeGetRank(shapedType);
11111108
// Prepare the shape for the buffer_info.
11121109
SmallVector<intptr_t, 4> shape;

0 commit comments

Comments
 (0)