Skip to content
Merged
151 changes: 106 additions & 45 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
Expand All @@ -37,16 +38,17 @@ using namespace mlir;

/// Returns a compressed mask. The mask value is set only if any mask is present
/// in the scale range. E.g., if `scale` equals to 2, and `intraDataOffset`
/// equals to 2, the following mask:
/// equals to 1 (intraDataOffset strictly smaller than scale), the following
/// mask:
///
/// %mask = [1, 1, 1, 0, 0, 0]
/// %mask = [1, 1, 0, 0, 0, 0]
///
/// will first be padded with `intraDataOffset` zeros:
/// %mask = [0, 0, 1, 1, 1, 0, 0, 0]
/// %mask = [0, 1, 1, 0, 0, 0, 0, 0]
///
/// then it will return the following new compressed mask:
///
/// %mask = [0, 1, 1, 0]
/// %mask = [1, 1, 0, 0]
static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
Location loc, Value mask,
int origElements, int scale,
Expand Down Expand Up @@ -75,9 +77,6 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
shape.back() = numElements;
auto newMaskType = VectorType::get(shape, rewriter.getI1Type());
if (createMaskOp) {
// TODO: handle the case with non-zero intraDataOffset for CreateMaskOp.
if (intraDataOffset != 0)
return failure();
OperandRange maskOperands = createMaskOp.getOperands();
size_t numMaskOperands = maskOperands.size();
AffineExpr s0;
Expand Down Expand Up @@ -129,26 +128,79 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
return newMask;
}

static Value extractSubvectorFrom(RewriterBase &rewriter, Location loc,
VectorType extractType, Value vector,
int64_t frontOffset, int64_t subvecSize) {
/// Extracts a 1-D subvector of `subvecSize` elements, starting at the static
/// position `frontOffset`, from the 1-D vector `source`. It is a wrapper
/// function that emits one unit-stride `vector.extract_strided_slice` whose
/// result type is `extractType`, and returns that result.
static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc,
                                        VectorType extractType, Value source,
                                        int64_t frontOffset,
                                        int64_t subvecSize) {
  // Only referenced by the assertions below; `[[maybe_unused]]` keeps builds
  // with assertions disabled free of unused-variable warnings.
  [[maybe_unused]] auto vectorType = cast<VectorType>(source.getType());
  assert((vectorType.getRank() == 1 && extractType.getRank() == 1) &&
         "expected 1-D source and destination types");
  assert(frontOffset >= 0 && subvecSize >= 0 &&
         frontOffset + subvecSize <= vectorType.getNumElements() &&
         "subvector out of bounds");
  auto offsets = rewriter.getI64ArrayAttr({frontOffset});
  auto sizes = rewriter.getI64ArrayAttr({subvecSize});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter
      .create<vector::ExtractStridedSliceOp>(loc, extractType, source, offsets,
                                             sizes, strides)
      ->getResult(0);
}

static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
Value src, Value dest, int64_t offset) {
/// Inserts the 1-D vector `src` into the 1-D vector `dest`, overwriting the
/// elements of `dest` starting at the static position `offset`. It is a
/// wrapper function that emits one unit-stride `vector.insert_strided_slice`
/// and returns its result, which has the same type as `dest`.
static Value staticallyInsertSubvector(OpBuilder &rewriter, Location loc,
                                       Value src, Value dest, int64_t offset) {
  // Only referenced by the assertions below; `[[maybe_unused]]` keeps builds
  // with assertions disabled free of unused-variable warnings.
  [[maybe_unused]] auto srcType = cast<VectorType>(src.getType());
  [[maybe_unused]] auto destType = cast<VectorType>(dest.getType());
  assert(srcType.getRank() == 1 && destType.getRank() == 1 &&
         "expected 1-D source and dest vector types");
  assert(offset >= 0 &&
         offset + srcType.getNumElements() <= destType.getNumElements() &&
         "subvector out of bounds");
  auto offsets = rewriter.getI64ArrayAttr({offset});
  auto strides = rewriter.getI64ArrayAttr({1});
  return rewriter.create<vector::InsertStridedSliceOp>(loc, dest.getType(), src,
                                                       dest, offsets, strides);
}

/// Extracts `numElementsToExtract` consecutive elements from the 1-D `source`
/// vector, starting at index `offset`, and inserts them into `dest` at
/// positions `0 .. numElementsToExtract - 1`. Returns the value produced by the
/// last insertion (or `dest` unchanged when `numElementsToExtract` is 0).
///
/// This function emits one `vector.extract` and one `vector.insert` per
/// element (plus the index arithmetic for every element after the first), so
/// only use it when `offset` cannot be folded into a constant value.
static Value dynamicallyExtractSubVector(OpBuilder &rewriter, Location loc,
                                         TypedValue<VectorType> source,
                                         Value dest, OpFoldResult offset,
                                         int64_t numElementsToExtract) {
  // Nothing to copy: avoid materialising an unused offset value.
  if (numElementsToExtract <= 0)
    return dest;

  // Resolve the offset to an SSA value once. If `offset` already wraps a Value
  // it is used as is; if it wraps an attribute a constant index is created
  // instead of silently producing a null Value.
  Value baseOffset = getValueOrCreateConstantIndexOp(rewriter, loc, offset);

  for (int64_t i = 0; i < numElementsToExtract; ++i) {
    // Element 0 is read directly at `offset`; element i at `offset + i`.
    Value extractLoc = baseOffset;
    if (i != 0) {
      Value step = rewriter.create<arith::ConstantIndexOp>(loc, i);
      extractLoc = rewriter.create<arith::AddIOp>(loc, rewriter.getIndexType(),
                                                  baseOffset, step);
    }
    auto extractOp =
        rewriter.create<vector::ExtractOp>(loc, source, extractLoc);
    dest = rewriter.create<vector::InsertOp>(loc, extractOp, dest, i);
  }
  return dest;
}

/// Returns the op sequence for an emulated sub-byte data type vector load.
/// Specifically, it uses `emulatedElemType` for loading a vector of
/// `origElemType`: a `vector.load` of `numEmulatedElementsToLoad` elements of
/// `emulatedElemType` is emitted at `base[linearizedIndices]`, and its result
/// is bitcast to a vector of `numEmulatedElementsToLoad * scale` elements of
/// `origElemType`, where `scale` is the ratio of the emulated element bit width
/// to the original element bit width.
static TypedValue<VectorType>
emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                   OpFoldResult linearizedIndices,
                   int64_t numEmulatedElementsToLoad, Type origElemType,
                   Type emulatedElemType) {
  // Number of original (narrow) elements packed into one emulated element.
  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
               origElemType.getIntOrFloatBitWidth();
  auto newLoad = rewriter.create<vector::LoadOp>(
      loc, VectorType::get(numEmulatedElementsToLoad, emulatedElemType), base,
      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
  // Reinterpret the loaded container elements as the original element type.
  return rewriter.create<vector::BitCastOp>(
      loc, VectorType::get(numEmulatedElementsToLoad * scale, origElemType),
      newLoad);
}

namespace {

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -380,25 +432,27 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;

if (!foldedIntraVectorOffset) {
// unimplemented case for dynamic intra vector offset
return failure();
}

// Always load enough elements which can cover the original elements.
int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
auto numElements =
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
auto newLoad = rewriter.create<vector::LoadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));

Value result = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements * scale, oldElementType), newLoad);

if (isUnalignedEmulation) {
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
llvm::divideCeil(maxintraDataOffset + origElements, scale);
Value result =
emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
numElements, oldElementType, newElementType);

if (foldedIntraVectorOffset) {
if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
} else {
auto resultVector = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(
rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
linearizedInfo.intraDataOffset, origElements);
}

rewriter.replaceOp(op, result);
return success();
}
Expand Down Expand Up @@ -513,8 +567,8 @@ struct ConvertVectorMaskedLoad final
// create an empty vector of the new type
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
passthru = insertSubvectorInto(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
auto newPassThru =
rewriter.create<vector::BitCastOp>(loc, loadType, passthru);
Expand All @@ -537,16 +591,17 @@ struct ConvertVectorMaskedLoad final
// TODO: can fold if op's mask is constant
auto emptyVector = rewriter.create<arith::ConstantOp>(
loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
mask = insertSubvectorInto(rewriter, loc, op.getMask(), emptyVector,
*foldedIntraVectorOffset);
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyVector,
*foldedIntraVectorOffset);
}

Value result =
rewriter.create<arith::SelectOp>(loc, mask, bitCast, passthru);

if (isUnalignedEmulation) {
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
rewriter.replaceOp(op, result);

Expand Down Expand Up @@ -604,13 +659,10 @@ struct ConvertVectorTransferRead final
? getConstantIntValue(linearizedInfo.intraDataOffset)
: 0;

if (!foldedIntraVectorOffset) {
// unimplemented case for dynamic inra-vector offset
return failure();
}

auto maxIntraVectorOffset =
foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
auto numElements =
llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
llvm::divideCeil(maxIntraVectorOffset + origElements, scale);

auto newRead = rewriter.create<vector::TransferReadOp>(
loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
Expand All @@ -621,9 +673,18 @@ struct ConvertVectorTransferRead final
loc, VectorType::get(numElements * scale, oldElementType), newRead);

Value result = bitCast->getResult(0);
if (isUnalignedEmulation) {
result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
if (foldedIntraVectorOffset) {
if (isUnalignedEmulation) {
result =
staticallyExtractSubvector(rewriter, loc, op.getType(), result,
*foldedIntraVectorOffset, origElements);
}
} else {
auto zeros = rewriter.create<arith::ConstantOp>(
loc, op.getType(), rewriter.getZeroAttr(op.getType()));
result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
linearizedInfo.intraDataOffset,
origElements);
}
rewriter.replaceOp(op, result);

Expand Down
Loading