@@ -1089,12 +1089,30 @@ class PyDenseElementsAttribute
10891089
10901090 py::module numpy = py::module::import (" numpy" );
10911091 py::object unpackbitsFunc = numpy.attr (" unpackbits" );
1092- py::object unpackedBooleans =
1092+ py::object equalFunc = numpy.attr (" equal" );
1093+ py::object reshapeFunc = numpy.attr (" reshape" );
1094+ py::array unpackedBooleans =
10931095 unpackbitsFunc (packedArray, " bitorder" _a = " little" );
1094- py::buffer pythonBuffer = unpackedBooleans.cast <py::buffer>();
1096+
1097+ // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
1098+ // We need to:
1099+ // 1. Slice away the padded bits
1100+ // 2. Make the boolean array have the correct shape
1101+ // 3. Convert the array to a boolean array
1102+ unpackedBooleans = unpackedBooleans[py::slice (0 , numBooleans, 1 )];
1103+ unpackedBooleans = equalFunc (unpackedBooleans, 1 );
1104+
1105+ std::vector<intptr_t > shape;
1106+ MlirType shapedType = mlirAttributeGetType (*this );
1107+ intptr_t rank = mlirShapedTypeGetRank (shapedType);
1108+ for (intptr_t i = 0 ; i < rank; ++i) {
1109+ shape.push_back (mlirShapedTypeGetDimSize (shapedType, i));
1110+ }
1111+ unpackedBooleans = reshapeFunc (unpackedBooleans, shape);
10951112
10961113 // Make sure the returned py::buffer_view claims ownership of the data in
10971114 // `pythonBuffer` so it remains valid when Python reads it
1115+ py::buffer pythonBuffer = unpackedBooleans.cast <py::buffer>();
10981116 return pythonBuffer.request ();
10991117 }
11001118
0 commit comments