Skip to content

Commit e480fa8

Browse files
committed
Fix the boolean array padding, type and shape
1 parent b11026a commit e480fa8

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

mlir/lib/Bindings/Python/IRAttributes.cpp

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,12 +1089,30 @@ class PyDenseElementsAttribute
10891089

10901090
py::module numpy = py::module::import("numpy");
10911091
py::object unpackbitsFunc = numpy.attr("unpackbits");
1092-
py::object unpackedBooleans =
1092+
py::object equalFunc = numpy.attr("equal");
1093+
py::object reshapeFunc = numpy.attr("reshape");
1094+
py::array unpackedBooleans =
10931095
unpackbitsFunc(packedArray, "bitorder"_a = "little");
1094-
py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
1096+
1097+
// Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
1098+
// We need to:
1099+
// 1. Slice away the padded bits
1100+
// 2. Make the boolean array have the correct shape
1101+
// 3. Convert the array to a boolean array
1102+
unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
1103+
unpackedBooleans = equalFunc(unpackedBooleans, 1);
1104+
1105+
std::vector<intptr_t> shape;
1106+
MlirType shapedType = mlirAttributeGetType(*this);
1107+
intptr_t rank = mlirShapedTypeGetRank(shapedType);
1108+
for (intptr_t i = 0; i < rank; ++i) {
1109+
shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1110+
}
1111+
unpackedBooleans = reshapeFunc(unpackedBooleans, shape);
10951112

10961113
// Make sure the returned py::buffer_view claims ownership of the data in
10971114
// `pythonBuffer` so it remains valid when Python reads it
1115+
py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
10981116
return pythonBuffer.request();
10991117
}
11001118

0 commit comments

Comments
 (0)