Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 195 additions & 95 deletions mlir/lib/Bindings/Python/IRAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "IRModule.h"

#include "PybindUtils.h"
#include <pybind11/numpy.h>

#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/raw_ostream.h"
Expand Down Expand Up @@ -757,103 +758,10 @@ class PyDenseElementsAttribute
throw py::error_already_set();
}
auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); });
SmallVector<int64_t> shape;
if (explicitShape) {
shape.append(explicitShape->begin(), explicitShape->end());
} else {
shape.append(view.shape, view.shape + view.ndim);
}

MlirAttribute encodingAttr = mlirAttributeGetNull();
MlirContext context = contextWrapper->get();

// Detect format codes that are suitable for bulk loading. This includes
// all byte aligned integer and floating point types up to 8 bytes.
// Notably, this excludes, bool (which needs to be bit-packed) and
// other exotics which do not have a direct representation in the buffer
// protocol (i.e. complex, etc).
std::optional<MlirType> bulkLoadElementType;
if (explicitType) {
bulkLoadElementType = *explicitType;
} else {
std::string_view format(view.format);
if (format == "f") {
// f32
assert(view.itemsize == 4 && "mismatched array itemsize");
bulkLoadElementType = mlirF32TypeGet(context);
} else if (format == "d") {
// f64
assert(view.itemsize == 8 && "mismatched array itemsize");
bulkLoadElementType = mlirF64TypeGet(context);
} else if (format == "e") {
// f16
assert(view.itemsize == 2 && "mismatched array itemsize");
bulkLoadElementType = mlirF16TypeGet(context);
} else if (isSignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeSignedGet(context, 32);
} else if (view.itemsize == 8) {
// i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeSignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeSignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeSignedGet(context, 16);
}
} else if (isUnsignedIntegerFormat(format)) {
if (view.itemsize == 4) {
// unsigned i32
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 32)
: mlirIntegerTypeUnsignedGet(context, 32);
} else if (view.itemsize == 8) {
// unsigned i64
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 64)
: mlirIntegerTypeUnsignedGet(context, 64);
} else if (view.itemsize == 1) {
// i8
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 8)
: mlirIntegerTypeUnsignedGet(context, 8);
} else if (view.itemsize == 2) {
// i16
bulkLoadElementType = signless
? mlirIntegerTypeGet(context, 16)
: mlirIntegerTypeUnsignedGet(context, 16);
}
}
if (!bulkLoadElementType) {
throw std::invalid_argument(
std::string("unimplemented array format conversion from format: ") +
std::string(format));
}
}

MlirType shapedType;
if (mlirTypeIsAShaped(*bulkLoadElementType)) {
if (explicitShape) {
throw std::invalid_argument("Shape can only be specified explicitly "
"when the type is not a shaped type.");
}
shapedType = *bulkLoadElementType;
} else {
shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(),
*bulkLoadElementType, encodingAttr);
}
size_t rawBufferSize = view.len;
MlirAttribute attr =
mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf);
MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType,
explicitShape, context);
if (mlirAttributeIsNull(attr)) {
throw std::invalid_argument(
"DenseElementsAttr could not be constructed from the given buffer. "
Expand Down Expand Up @@ -963,6 +871,13 @@ class PyDenseElementsAttribute
// unsigned i16
return bufferInfo<uint16_t>(shapedType);
}
} else if (mlirTypeIsAInteger(elementType) &&
mlirIntegerTypeGetWidth(elementType) == 1) {
// i1 / bool
// We can not send the buffer directly back to Python, because the i1
// values are bitpacked within MLIR. We call numpy's unpackbits function
// to convert the bytes.
return getBooleanBufferFromBitpackedAttribute();
}

// TODO: Currently crashes the program.
Expand Down Expand Up @@ -1016,6 +931,191 @@ class PyDenseElementsAttribute
code == 'q';
}

// Resolves the shaped type that describes the contents of `view`.
//
// `bulkLoadElementType` must hold a value; it is dereferenced
// unconditionally. If it is already a shaped type it is returned as-is, and
// combining it with `explicitShape` is rejected with std::invalid_argument.
// Otherwise a ranked tensor type is created with that element type, using
// `explicitShape` when provided and the dimensions of `view` otherwise.
static MlirType
getShapedType(std::optional<MlirType> bulkLoadElementType,
              std::optional<std::vector<int64_t>> explicitShape,
              Py_buffer &view) {
  const MlirType requestedType = *bulkLoadElementType;

  // A shaped type already carries its own dimensions, so there is nothing to
  // derive from the buffer or from the caller supplied shape.
  if (mlirTypeIsAShaped(requestedType)) {
    if (explicitShape.has_value())
      throw std::invalid_argument("Shape can only be specified explicitly "
                                  "when the type is not a shaped type.");
    return requestedType;
  }

  // Plain element type: wrap it into a ranked tensor without an encoding.
  SmallVector<int64_t> dims;
  if (explicitShape.has_value())
    dims.append(explicitShape->begin(), explicitShape->end());
  else
    dims.append(view.shape, view.shape + view.ndim);

  return mlirRankedTensorTypeGet(dims.size(), dims.data(), requestedType,
                                 mlirAttributeGetNull());
}

// Builds a dense elements attribute from the Python buffer `view`.
//
// When `explicitType` is provided it is used directly for the bulk load.
// Otherwise the element type is derived from the buffer format code:
//   "f" / "d" / "e"        -> f32 / f64 / f16
//   "?"                    -> i1, delegated to the bit-packing helper
//   signed/unsigned ints   -> 8, 16, 32 or 64 bit integers, signless when
//                             `signless` is set
// Exotic types which do not have a direct representation in the buffer
// protocol (i.e. complex, etc) are not supported and raise
// std::invalid_argument.
static MlirAttribute getAttributeFromBuffer(
    Py_buffer &view, bool signless, std::optional<PyType> explicitType,
    std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) {
  // Picks the signless, signed or unsigned integer type of the given width.
  auto integerTypeOfWidth = [&](unsigned width, bool isSigned) -> MlirType {
    if (signless)
      return mlirIntegerTypeGet(context, width);
    if (isSigned)
      return mlirIntegerTypeSignedGet(context, width);
    return mlirIntegerTypeUnsignedGet(context, width);
  };

  // Maps the buffer item size (in bytes) onto an integer type. Unsupported
  // sizes yield no type, which is reported as an error below.
  auto integerTypeForItemSize =
      [&](bool isSigned) -> std::optional<MlirType> {
    switch (view.itemsize) {
    case 1:
      return integerTypeOfWidth(8, isSigned);
    case 2:
      return integerTypeOfWidth(16, isSigned);
    case 4:
      return integerTypeOfWidth(32, isSigned);
    case 8:
      return integerTypeOfWidth(64, isSigned);
    default:
      return std::nullopt;
    }
  };

  std::optional<MlirType> elementType;
  if (explicitType) {
    elementType = *explicitType;
  } else {
    const std::string_view formatCode(view.format);

    // The i1 type needs to be bit-packed, so it is handled separately.
    if (formatCode == "?")
      return getBitpackedAttributeFromBooleanBuffer(view, explicitShape,
                                                    context);

    if (formatCode == "f") {
      assert(view.itemsize == 4 && "mismatched array itemsize");
      elementType = mlirF32TypeGet(context);
    } else if (formatCode == "d") {
      assert(view.itemsize == 8 && "mismatched array itemsize");
      elementType = mlirF64TypeGet(context);
    } else if (formatCode == "e") {
      assert(view.itemsize == 2 && "mismatched array itemsize");
      elementType = mlirF16TypeGet(context);
    } else if (isSignedIntegerFormat(formatCode)) {
      elementType = integerTypeForItemSize(/*isSigned=*/true);
    } else if (isUnsignedIntegerFormat(formatCode)) {
      elementType = integerTypeForItemSize(/*isSigned=*/false);
    }

    if (!elementType) {
      throw std::invalid_argument(
          std::string("unimplemented array format conversion from format: ") +
          std::string(formatCode));
    }
  }

  // Bulk load: the raw bytes of the buffer are handed to MLIR as they are.
  return mlirDenseElementsAttrRawBufferGet(
      getShapedType(elementType, explicitShape, view), view.len, view.buf);
}

// Creates an i1 dense elements attribute from a boolean Python buffer.
//
// numpy represents a boolean array with 8 bits (1 byte) per boolean, whereas
// MLIR bitpacks 8 booleans into every byte. The conversion between the two
// layouts is delegated to numpy's `packbits` with little bit order. Throws
// py::type_error on big-endian hosts.
static MlirAttribute getBitpackedAttributeFromBooleanBuffer(
    Py_buffer &view, std::optional<std::vector<int64_t>> explicitShape,
    MlirContext &context) {
  if (llvm::endianness::native != llvm::endianness::little) {
    // There is no good way of testing the behavior on big-endian systems, so
    // refuse instead of risking a silently wrong result.
    throw py::type_error("Constructing a bit-packed MLIR attribute is "
                         "unsupported on big-endian systems");
  }

  // View the incoming booleans as one uint8 per element for numpy.
  auto *bytePerBoolean = static_cast<uint8_t *>(view.buf);
  py::array_t<uint8_t> byteArray(view.len, bytePerBoolean);

  py::object packedArray = py::module::import("numpy").attr("packbits")(
      byteArray, "bitorder"_a = "little");
  py::buffer_info packedInfo = packedArray.cast<py::buffer>().request();

  MlirType i1Type = mlirIntegerTypeGet(context, 1);
  MlirType tensorType = getShapedType(i1Type, explicitShape, view);
  assert(packedInfo.itemsize == 1 && "Packbits must return uint8");

  // `mlirDenseElementsAttrRawBufferGet` copies the memory of `packedArray`,
  // so the returned attribute stays valid after `packedArray` is reclaimed
  // at the end of this function.
  return mlirDenseElementsAttrRawBufferGet(tensorType, packedInfo.size,
                                           packedInfo.ptr);
}

// This does the opposite transformation of
// `getBitpackedAttributeFromBooleanBuffer`: it exposes the i1 values of this
// attribute as a numpy boolean array (one byte per boolean) with the shape of
// the attribute, and returns the buffer info of that array.
//
// Splat attributes are handled separately: MLIR stores only a single value
// for them, so the raw data must not be read as if it held
// `ceil(numElements / 8)` bytes. Throws py::type_error on big-endian hosts.
py::buffer_info getBooleanBufferFromBitpackedAttribute() {
  if (llvm::endianness::native != llvm::endianness::little) {
    // Given we have no good way of testing the behavior on big-endian systems
    // we will throw
    throw py::type_error("Constructing a numpy array from a MLIR attribute "
                         "is unsupported on big-endian systems");
  }

  py::module numpy = py::module::import("numpy");
  int64_t numBooleans = mlirElementsAttrGetNumElements(*this);

  // Flat boolean array holding one entry per element of the attribute.
  py::array unpackedBooleans;
  if (mlirDenseElementsAttrIsSplat(*this)) {
    // Only the splat value is stored; broadcast it instead of reading
    // bit-packed bytes that do not exist.
    const bool splatValue = mlirDenseElementsAttrGetBoolSplatValue(*this) != 0;
    unpackedBooleans = numpy.attr("full")(numBooleans, splatValue,
                                          "dtype"_a = numpy.attr("bool_"));
  } else {
    int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8);
    uint8_t *bitpackedData = static_cast<uint8_t *>(
        const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
    py::array_t<uint8_t> packedArray(numBitpackedBytes, bitpackedData);

    // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array.
    // We need to:
    //   1. Slice away the padded bits
    //   2. Convert the array to a boolean array
    unpackedBooleans =
        numpy.attr("unpackbits")(packedArray, "bitorder"_a = "little");
    unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)];
    unpackedBooleans = numpy.attr("equal")(unpackedBooleans, 1);
  }

  // Give the flat boolean array the shape of the attribute.
  std::vector<intptr_t> shape;
  MlirType shapedType = mlirAttributeGetType(*this);
  intptr_t rank = mlirShapedTypeGetRank(shapedType);
  shape.reserve(rank);
  for (intptr_t i = 0; i < rank; ++i) {
    shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
  }
  unpackedBooleans = numpy.attr("reshape")(unpackedBooleans, shape);

  // The buffer info is requested from the numpy array itself so that the
  // data it refers to remains valid when Python reads it.
  py::buffer pythonBuffer = unpackedBooleans.cast<py::buffer>();
  return pythonBuffer.request();
}

template <typename Type>
py::buffer_info bufferInfo(MlirType shapedType,
const char *explicitFormat = nullptr) {
Expand Down
72 changes: 72 additions & 0 deletions mlir/test/python/ir/array_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,78 @@ def testGetDenseElementsF64():
print(np.array(attr))


### 1 bit/boolean integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI1Signless
@run
def testGetDenseElementsI1Signless():
    # Round-trips boolean numpy arrays through DenseElementsAttr. Each case
    # prints the attribute and then the numpy array recovered from it; the
    # CHECK lines in front of each case match those two outputs in order.
    def roundtrip(values):
        attr = DenseElementsAttr.get(np.array(values, dtype=np.bool_))
        print(attr)
        print(np.array(attr))

    with Context():
        # Single element.
        # CHECK: dense<true> : tensor<1xi1>
        # CHECK{LITERAL}: [ True]
        roundtrip([True])

        # Fewer than 8 elements: fits into one partially filled byte.
        # CHECK{LITERAL}: dense<[[true, false, true], [true, true, false]]> : tensor<2x3xi1>
        # CHECK{LITERAL}: [[ True False True]
        # CHECK{LITERAL}: [ True True False]]
        roundtrip([[True, False, True], [True, True, False]])

        # Exactly 8 elements: one completely filled byte.
        # CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false]]> : tensor<2x4xi1>
        # CHECK{LITERAL}: [[ True True False False]
        # CHECK{LITERAL}: [ True False True False]]
        roundtrip([[True, True, False, False], [True, False, True, False]])

        # 20 elements: several bytes, the last one partially filled.
        # CHECK{LITERAL}: dense<[[true, true, false, false], [true, false, true, false], [false, false, false, false], [true, true, true, true], [true, false, false, true]]> : tensor<5x4xi1>
        # CHECK{LITERAL}: [[ True True False False]
        # CHECK{LITERAL}: [ True False True False]
        # CHECK{LITERAL}: [False False False False]
        # CHECK{LITERAL}: [ True True True True]
        # CHECK{LITERAL}: [ True False False True]]
        roundtrip(
            [
                [True, True, False, False],
                [True, False, True, False],
                [False, False, False, False],
                [True, True, True, True],
                [True, False, False, True],
            ]
        )

        # Rows of 9 elements: rows are not aligned to byte boundaries.
        # CHECK{LITERAL}: dense<[[true, true, false, false, true, true, false, false, false], [false, false, false, true, false, true, true, false, true]]> : tensor<2x9xi1>
        # CHECK{LITERAL}: [[ True True False False True True False False False]
        # CHECK{LITERAL}: [False False False True False True True False True]]
        roundtrip(
            [
                [True, True, False, False, True, True, False, False, False],
                [False, False, False, True, False, True, True, False, True],
            ]
        )

        # Empty array.
        # CHECK: dense<> : tensor<0xi1>
        # CHECK{LITERAL}: []
        roundtrip([])


### 16 bit integer arrays
# CHECK-LABEL: TEST: testGetDenseElementsI16Signless
@run
Expand Down
Loading