121 changes: 101 additions & 20 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -432,7 +432,86 @@ namespace {
// ConvertVectorStore
//===----------------------------------------------------------------------===//

// TODO: Document-me
// Emulate vector.store using a multi-byte container type
//
// The container type is obtained through the Op adaptor and would normally
// be generated via `NarrowTypeEmulationConverter`.
//
// EXAMPLE 1
// (aligned store of i4, emulated using i8)
//
// vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
//
// is rewritten as:
//
// %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
// vector.store %src_bitcast, %dest_bitcast[%idx]
// : memref<16xi8>, vector<4xi8>
//
// EXAMPLE 2
// (unaligned store of i2, emulated using i8, non-atomic)
//
// vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
//
// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
// is modelled using 3 bytes:
//
// Byte 0 Byte 1 Byte 2
// +----------+----------+----------+
// | oooooooo | ooooNNNN | NNoooooo |
// +----------+----------+----------+
//
// N - (N)ew entries (i.e. to be overwritten by vector.store)
// o - (o)ld entries (to be preserved)
//
// The following 2 RMW sequences will be generated:
//
// %init = arith.constant dense<0> : vector<4xi2>
//
// (RMW sequence for Byte 1)
// (Mask for 4 x i2 elements, i.e. a byte)
// %mask_1 = arith.constant dense<[false, false, true, true]>
// %src_slice_1 = vector.extract_strided_slice %src
// {offsets = [0], sizes = [2], strides = [1]}
// : vector<3xi2> to vector<2xi2>
// %init_with_slice_1 = vector.insert_strided_slice %src_slice_1, %init
// {offsets = [2], strides = [1]}
// : vector<2xi2> into vector<4xi2>
// %dest_byte_1 = vector.load %dest[%c1]
// %dest_byte_1_as_i2 = vector.bitcast %dest_byte_1
// : vector<1xi8> to vector<4xi2>
// %res_byte_1 = arith.select %mask_1, %init_with_slice_1, %dest_byte_1_as_i2
// %res_byte_1_as_i8 = vector.bitcast %res_byte_1
// vector.store %res_byte_1_as_i8, %dest[%c1]
//
// (RMW sequence for Byte 2)
// (Mask for 4 x i2 elements, i.e. a byte)
// %mask_2 = arith.constant dense<[true, false, false, false]>
// %src_slice_2 = vector.extract_strided_slice %src
//    {offsets = [2], sizes = [1], strides = [1]}
//    : vector<3xi2> to vector<1xi2>
// %init_with_slice_2 = vector.insert_strided_slice %src_slice_2, %init
//    {offsets = [0], strides = [1]}
//    : vector<1xi2> into vector<4xi2>
// %dest_byte_2 = vector.load %dest[%c2]
Member:
I think we can refer the reader to the corresponding test case, and ideally annotate/comment it more precisely there instead.

Contributor Author:
Thanks for the suggestion! I was wondering how to avoid this long comment, and your suggestion is exactly what we should be doing! 🙏🏻

As this example is taken from "vector-emulate-narrow-type-unaligned-non-atomic.mlir", that's the test file that I've updated to help here. Please check the latest update.

Note, I've made quite a few changes:

  • Extended comments.
  • Fixed DOWNCAST vs UPCAST.
  • Renamed some variables to avoid generic names (e.g. %arg0 -> %src, %0 -> %dest).
  • Added more CHECK lines, e.g. `// CHECK-SAME: : vector<1xi8> to vector<4xi2>`, to make sure that the right casting is generated.
  • Followed the formatting style from vectorize-convolution.mlir. IMHO it's a very "readable" style that's particularly handy for complex tests like these.

I appreciate that these are quite intrusive changes, but since it's meant as documentation, it felt like the right thing to do. I'm happy to adapt/revert if you feel this is too much.

Thanks for reviewing!

// %dest_byte_2_as_i2 = vector.bitcast %dest_byte_2
//    : vector<1xi8> to vector<4xi2>
// %res_byte_2 = arith.select %mask_2, %init_with_slice_2, %dest_byte_2_as_i2
// %res_byte_2_as_i8 = vector.bitcast %res_byte_2
// vector.store %res_byte_2_as_i8, %dest[%c2]
//
// NOTE: Unlike EXAMPLE 1, this case requires index re-calculation.
// NOTE: This example assumes that `disableAtomicRMW` was set.
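//
// For reference, the index re-calculation for this example follows directly
// from the shapes above: the destination indices [%c2, %c0] linearize to
// element 2 * 3 + 0 = 6 of the flattened i2 memref. With 4 x i2 elements per
// i8 byte, that is container byte 6 / 4 = 1, with an intra-byte offset of
// 6 % 4 = 2, hence the first RMW sequence targets Byte 1 and preserves its
// first 2 (old) elements.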
//
// EXAMPLE 3
// (unaligned store of i2, emulated using i8, atomic)
//
// Similar to EXAMPLE 2, but each RMW sequence is wrapped in
// `memref.generic_atomic_rmw` to guarantee atomicity. The actual output is
// skipped for brevity.
//
// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
// `true` to generate non-atomic RMW sequences.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
using OpConversionPattern::OpConversionPattern;

@@ -464,7 +543,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
op, "impossible to pack emulated elements into container elements "
"(bit-wise misalignment)");
}
int numSrcElemsPerDest = containerBits / emulatedBits;
int emulatedPerContainerElem = containerBits / emulatedBits;

// Adjust the number of elements to store when emulating narrow types.
// Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +559,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// vector<4xi8>

auto origElements = valueToStore.getType().getNumElements();
bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
// Note, per-element-alignment was already verified above.
Member:
I think so far we cannot support cases that are not per-element aligned? I.e. we cannot support 7-bit emulation?

Contributor Author:
AFAIK, all patterns require per-element alignment, yes. We should extract that condition somewhere to avoid repeating it.

On a related note, given how quickly this logic is growing, I feel that we should try to avoid emulating i3, i5 and i7 unless that's really required. I am just worried that the potential cost of supporting those would be relatively high.
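For illustration, the extracted check could look roughly like this (a rough sketch only; the helper name, signature and placement are placeholders, not part of this PR):

```cpp
// Hypothetical helper: the per-element alignment check that each emulation
// pattern currently repeats, factored out into one place.
static LogicalResult verifyPerElementAlignment(Operation *op,
                                               ConversionPatternRewriter &rewriter,
                                               int containerBits,
                                               int emulatedBits) {
  // Per-element alignment: the container type must hold a whole number of
  // emulated elements.
  if (containerBits % emulatedBits != 0)
    return rewriter.notifyMatchFailure(
        op, "impossible to pack emulated elements into container elements "
            "(bit-wise misalignment)");
  return success();
}
```

Each pattern would then simply early-return on failure instead of duplicating the diagnostic.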

bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -496,7 +576,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
getAsOpFoldResult(adaptor.getIndices()));

std::optional<int64_t> foldedNumFrontPadElems =
isAlignedEmulation
isFullyAligned
? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);

@@ -516,10 +596,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// need unaligned emulation because the store address is aligned and the
// source is a whole byte.
bool emulationRequiresPartialStores =
!isAlignedEmulation || *foldedNumFrontPadElems != 0;
!isFullyAligned || *foldedNumFrontPadElems != 0;
if (!emulationRequiresPartialStores) {
// Basic case: storing full bytes.
auto numElements = origElements / numSrcElemsPerDest;
auto numElements = origElements / emulatedPerContainerElem;
auto bitCast = rewriter.create<vector::BitCastOp>(
loc, VectorType::get(numElements, containerElemTy),
op.getValueToStore());
@@ -567,7 +647,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

// Build a mask used for rmw.
auto subWidthStoreMaskType =
VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;

@@ -576,10 +656,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// with the unaligned part so that the rest elements are aligned to width
// boundary.
auto frontSubWidthStoreElem =
(numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
(emulatedPerContainerElem - *foldedNumFrontPadElems) % emulatedPerContainerElem;
if (frontSubWidthStoreElem > 0) {
SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
origElements, true);
frontSubWidthStoreElem = origElements;
@@ -590,7 +670,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto frontMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
auto value =
extractSliceIntoByte(rewriter, loc, valueToStore, 0,
frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +694,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
// After the previous step, the store address is aligned to the emulated
// width boundary.
int64_t fullWidthStoreSize =
(origElements - currentSourceIndex) / numSrcElemsPerDest;
int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
(origElements - currentSourceIndex) / emulatedPerContainerElem;
int64_t numNonFullWidthElements = fullWidthStoreSize * emulatedPerContainerElem;
if (fullWidthStoreSize > 0) {
auto fullWidthStorePart = staticallyExtractSubvector(
rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +704,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
auto originType = cast<VectorType>(fullWidthStorePart.getType());
auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
auto storeType = VectorType::get(
{originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
{originType.getNumElements() / emulatedPerContainerElem}, memrefElemType);
auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
fullWidthStorePart);
rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +726,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
currentSourceIndex, remainingElements, 0);

// Generate back mask.
auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
std::fill_n(maskValues.begin(), remainingElements, 1);
auto backMask = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
@@ -960,7 +1040,8 @@ struct ConvertVectorMaskedLoad final
// subvector at the proper offset after bit-casting.
auto origType = op.getVectorType();
auto origElements = origType.getNumElements();
bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
// Note, per-element-alignment was already verified above.
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

auto stridedMetadata =
rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,7 +1056,7 @@ struct ConvertVectorMaskedLoad final
getAsOpFoldResult(adaptor.getIndices()));

std::optional<int64_t> foldedIntraVectorOffset =
isAlignedEmulation
isFullyAligned
? 0
: getConstantIntValue(linearizedInfo.intraDataOffset);

@@ -1001,7 +1082,7 @@ struct ConvertVectorMaskedLoad final
passthru = dynamicallyInsertSubVector(
rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
origElements);
} else if (!isAlignedEmulation) {
} else if (!isFullyAligned) {
passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
*foldedIntraVectorOffset);
}
@@ -1029,7 +1110,7 @@ struct ConvertVectorMaskedLoad final
mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
linearizedInfo.intraDataOffset,
origElements);
} else if (!isAlignedEmulation) {
} else if (!isFullyAligned) {
mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
*foldedIntraVectorOffset);
}
@@ -1040,7 +1121,7 @@ struct ConvertVectorMaskedLoad final
result = dynamicallyExtractSubVector(
rewriter, loc, result, op.getPassThru(),
linearizedInfo.intraDataOffset, origElements);
} else if (!isAlignedEmulation) {
} else if (!isFullyAligned) {
result = staticallyExtractSubvector(
rewriter, loc, result, *foldedIntraVectorOffset, origElements);
}