78 changes: 45 additions & 33 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -290,13 +290,15 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                       int64_t numContainerElemsToLoad,
                                       Type emulatedElemTy,
                                       Type containerElemTy) {
-  auto scale = containerElemTy.getIntOrFloatBitWidth() /
-               emulatedElemTy.getIntOrFloatBitWidth();
+  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
+                                  emulatedElemTy.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numContainerElemsToLoad * scale, emulatedElemTy),
+      loc,
+      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
+                      emulatedElemTy),
       newLoad);
 }
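
Aside (not part of the patch): a minimal standalone sketch of the packing arithmetic above, assuming an i4 emulated type inside an i8 container. Plain C++, not MLIR:

#include <cassert>

// Minimal sketch of the sizing arithmetic in emulatedVectorLoad, assuming
// an i4 emulated element type inside an i8 container element type.
int main() {
  int containerBits = 8, emulatedBits = 4;
  // Number of narrow (emulated) elements packed into one container element.
  int emulatedPerContainerElem = containerBits / emulatedBits; // 2
  assert(emulatedPerContainerElem == 2);

  // Loading 4 container elements (vector<4xi8>) and bit-casting exposes
  // 4 * 2 = 8 emulated elements (vector<8xi4>).
  int numContainerElemsToLoad = 4;
  assert(numContainerElemsToLoad * emulatedPerContainerElem == 8);
  return 0;
}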

@@ -388,10 +390,11 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
          "sliceNumElements * vector element size must be less than or equal to 8");
   assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
          "vector element must be a valid sub-byte type");
-  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
   auto emptyByteVector = rewriter.create<arith::ConstantOp>(
-      loc, VectorType::get({scale}, vectorElementType),
-      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+      loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
+      rewriter.getZeroAttr(
+          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
   auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                               extractOffset, sliceNumElements);
   return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
@@ -656,9 +659,9 @@ struct ConvertVectorMaskedStore final
               "(bit-wise misalignment)");
     }
 
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % emulatedPerContainerElem != 0)
       return failure();
 
     auto stridedMetadata =
@@ -707,12 +710,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
     // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + emulatedPerContainerElem - 1) /
+                       emulatedPerContainerElem;
     auto newType = VectorType::get(numElements, containerElemTy);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
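
Aside (not part of the patch): a hedged standalone sketch of the shape arithmetic around getCompressedMaskOp. The real helper also handles offsets and non-constant masks; this only illustrates one plausible OR-based compression plus the ceil-division above:

#include <cassert>
#include <vector>

// Sketch: compress a per-narrow-element mask to a per-container mask by
// OR-ing the lanes that share a container element (i4 in i8 assumed).
int main() {
  int emulatedPerContainerElem = 2;
  std::vector<bool> origMask = {1, 1, 1, 0, 0, 0}; // 6 x i4 lanes
  int origElements = static_cast<int>(origMask.size());

  // Same ceil-division as `numElements` above: 6 i4 lanes -> 3 i8 containers.
  int numElements =
      (origElements + emulatedPerContainerElem - 1) / emulatedPerContainerElem;
  assert(numElements == 3);

  std::vector<bool> compressed(numElements, false);
  for (int i = 0; i < origElements; ++i)
    compressed[i / emulatedPerContainerElem] =
        compressed[i / emulatedPerContainerElem] || origMask[i];
  // compressed == {1, 1, 0}
  assert(compressed[0] && compressed[1] && !compressed[2]);
  return 0;
}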
@@ -721,7 +725,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
+    auto newBitCastType =
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -765,7 +770,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -797,7 +802,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
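
Aside (not part of the patch): the alignment check above in isolation — emulation is "aligned" exactly when the original element count is a whole number of container elements:

#include <cassert>

// Sketch of isAlignedEmulation, assuming i4 emulated in i8.
int main() {
  int emulatedPerContainerElem = 2;
  assert(8 % emulatedPerContainerElem == 0); // vector<8xi4>: aligned path
  assert(3 % emulatedPerContainerElem != 0); // vector<3xi4>: unaligned path
  return 0;
}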
@@ -818,9 +823,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        emulatedPerContainerElem);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, emulatedElemTy, containerElemTy);
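
Aside (not part of the patch): the worst-case sizing in isolation — when the intra-container offset is dynamic, the pattern assumes maximum front padding so the container load still covers all original elements:

#include <cassert>
#include <cstdint>
#include <optional>

// Sketch of the value_or / divideCeil sizing above (i4 in i8 assumed).
int main() {
  int emulatedPerContainerElem = 2;
  int origElements = 3; // vector<3xi4>

  std::optional<int64_t> foldedIntraVectorOffset; // unknown at compile time
  int64_t maxIntraDataOffset =
      foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1); // 1

  // divideCeil(1 + 3, 2) == 2 containers -> vector<2xi8> -> 4 i4 lanes.
  auto divideCeil = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
  assert(divideCeil(maxIntraDataOffset + origElements,
                    emulatedPerContainerElem) == 2);
  return 0;
}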
@@ -870,7 +876,7 @@ struct ConvertVectorMaskedLoad final
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -916,7 +922,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -935,18 +941,21 @@ struct ConvertVectorMaskedLoad final
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            emulatedPerContainerElem, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
 
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        emulatedPerContainerElem);
     auto loadType = VectorType::get(numElements, containerElemTy);
-    auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
+    auto newBitcastType =
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
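
Aside (not part of the patch): the shape invariant behind newBitcastType — vector.bitcast preserves total bit width, so the loaded containers expose exactly numElements * emulatedPerContainerElem narrow lanes for the passthru staging:

#include <cassert>

// Sketch: vector<3xi8> bit-casts to vector<6xi4>; total bits are unchanged.
int main() {
  int containerBits = 8, emulatedBits = 4;
  int emulatedPerContainerElem = containerBits / emulatedBits;
  int numElements = 3;                                // containers loaded
  int lanes = numElements * emulatedPerContainerElem; // narrow lanes exposed
  assert(numElements * containerBits == lanes * emulatedBits);
  return 0;
}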
@@ -973,8 +982,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * emulatedPerContainerElem, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
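
Aside (not part of the patch): a sketch of the widened i1 select mask. Because the masked load executes on containers, the original per-lane mask is re-applied at full narrow-lane resolution to choose between loaded data and passthru (hypothetical values):

#include <cassert>
#include <vector>

// Sketch: original vector<3xi1> mask placed into a vector<4xi1> select mask
// (2 i8 containers * 2 i4 lanes each); the padded lane stays false.
int main() {
  int emulatedPerContainerElem = 2;
  std::vector<bool> origMask = {1, 0, 1}; // 3 x i4 lanes
  int numElements = 2;                    // containers loaded
  std::vector<bool> selectMask(numElements * emulatedPerContainerElem, false);
  for (size_t i = 0; i < origMask.size(); ++i)
    selectMask[i] = origMask[i]; // remaining lanes keep the passthru value
  assert(selectMask.size() == 4 && !selectMask[3]);
  return 0;
}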
@@ -1033,11 +1042,11 @@ struct ConvertVectorTransferRead final
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
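
Aside (not part of the patch): the padding widening in isolation — the sub-byte padding value is zero-extended (arith.extui) to the container type. A sketch, assuming the value already fits in 4 bits:

#include <cassert>
#include <cstdint>

// Sketch: an i4 padding value 0b0101 zero-extends to the i8 value 0x05.
int main() {
  uint8_t padI4 = 0x5;          // 4-bit padding value
  uint8_t padI8 = padI4 & 0x0F; // zero-extension: upper bits are zero
  assert(padI8 == 0x05);
  return 0;
}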
@@ -1060,17 +1069,20 @@ struct ConvertVectorTransferRead final
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        emulatedPerContainerElem);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices),
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
+        loc,
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
+        newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {