
Commit ffd3552

[mlir][vector] Document ConvertVectorStore + unify var names (nfc)
1. Documents `ConvertVectorStore`.
2. As a follow-on for llvm#123527, renames `isAlignedEmulation` to
   `isFullyAligned` and `numSrcElemsPerDest` to `emulatedPerContainerElem`.
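For context on the renamed quantities, a small standalone sketch (illustrative only, not part of the commit; the concrete sizes are taken from EXAMPLE 1 in the new comment block below): `emulatedPerContainerElem` counts how many narrow elements fit in one container element, and `isFullyAligned` holds when the stored vector occupies a whole number of containers.

// Illustrative sketch only (not from the commit): the size arithmetic behind
// EXAMPLE 1, where a fully aligned i4 store is emulated with i8 containers.
#include <cassert>
#include <cstdio>

int main() {
  const int emulatedBits = 4;  // i4 element type being stored
  const int containerBits = 8; // i8 container type in the emulated memref
  const int origElements = 8;  // vector<8xi4> value to store

  // Mirrors `emulatedPerContainerElem = containerBits / emulatedBits`.
  const int emulatedPerContainerElem = containerBits / emulatedBits; // 2
  // Mirrors `isFullyAligned = origElements % emulatedPerContainerElem == 0`.
  const bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
  assert(isFullyAligned && "EXAMPLE 1 takes the full-byte store path");

  // Basic case (full bytes): vector<8xi4> bitcasts to vector<4xi8>.
  const int numElements = origElements / emulatedPerContainerElem; // 4
  std::printf("store vector<%dxi%d> as vector<%dxi%d>\n", origElements,
              emulatedBits, numElements, containerBits);
  return 0;
}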

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 101 additions & 20 deletions
@@ -432,7 +432,86 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//

-// TODO: Document-me
+// Emulate vector.store using a multi-byte container type
+//
+// The container type is obtained through the Op adaptor and would normally be
+// generated via `NarrowTypeEmulationConverter`.
+//
+// EXAMPLE 1
+// (aligned store of i4, emulated using i8)
+//
+//   vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
+//
+// is rewritten as:
+//
+//   %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+//   vector.store %src_bitcast, %dest_bitcast[%idx]
+//     : memref<16xi8>, vector<4xi8>
+//
+// EXAMPLE 2
+// (unaligned store of i2, emulated using i8, non-atomic)
+//
+//   vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+//
+// The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
+// is modelled using 3 bytes:
+//
+//    Byte 0     Byte 1     Byte 2
+// +----------+----------+----------+
+// | oooooooo | ooooNNNN | NNoooooo |
+// +----------+----------+----------+
+//
+// N - (N)ew entries (i.e. to be overwritten by vector.store)
+// o - (o)ld entries (to be preserved)
+//
+// The following 2 RMW sequences will be generated:
+//
+//   %init = arith.constant dense<0> : vector<4xi2>
+//
+//   (RMW sequence for Byte 1)
+//   (Mask for 4 x i2 elements, i.e. a byte)
+//   %mask_1 = arith.constant dense<[false, false, true, true]>
+//   %src_slice_1 = vector.extract_strided_slice %src
+//     {offsets = [0], sizes = [2], strides = [1]}
+//     : vector<3xi2> to vector<2xi2>
+//   %init_with_slice_1 = vector.insert_strided_slice %src_slice_1, %init
+//     {offsets = [2], strides = [1]}
+//     : vector<2xi2> into vector<4xi2>
+//   %dest_byte_1 = vector.load %dest[%c1]
+//   %dest_byte_1_as_i2 = vector.bitcast %dest_byte_1
+//     : vector<1xi8> to vector<4xi2>
+//   %res_byte_1 = arith.select %mask_1, %init_with_slice_1, %dest_byte_1_as_i2
+//   %res_byte_1_as_i8 = vector.bitcast %res_byte_1
+//   vector.store %res_byte_1_as_i8, %dest[1]
+//
+//   (RMW sequence for Byte 2)
+//   (Mask for 4 x i2 elements, i.e. a byte)
+//   %mask_2 = arith.constant dense<[true, false, false, false]>
+//   %src_slice_2 = vector.extract_strided_slice %src
+//     {offsets = [2], sizes = [1], strides = [1]}
+//     : vector<3xi2> to vector<1xi2>
+//   %init_with_slice_2 = vector.insert_strided_slice %src_slice_2, %init
+//     {offsets = [0], strides = [1]}
+//     : vector<1xi2> into vector<4xi2>
+//   %dest_byte_2 = vector.load %dest[%c2]
+//   %dest_byte_2_as_i2 = vector.bitcast %dest_byte_2
+//     : vector<1xi8> to vector<4xi2>
+//   %res_byte_2 = arith.select %mask_2, %init_with_slice_2, %dest_byte_2_as_i2
+//   %res_byte_2_as_i8 = vector.bitcast %res_byte_2
+//   vector.store %res_byte_2_as_i8, %dest[2]
+//
+// NOTE: Unlike EXAMPLE 1, this case requires index re-calculation.
+// NOTE: This example assumes that `disableAtomicRMW` was set.
+//
+// EXAMPLE 3
+// (unaligned store of i2, emulated using i8, atomic)
+//
+// Similar to EXAMPLE 2, with the addition of
+//   * `memref.generic_atomic_rmw`,
+// to guarantee atomicity. The actual output is skipped for brevity.
+//
+// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
+// `true` to generate non-atomic RMW sequences.
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;

@@ -464,7 +543,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int numSrcElemsPerDest = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;

     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +559,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>

     auto origElements = valueToStore.getType().getNumElements();
-    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -496,7 +576,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedNumFrontPadElems =
-        isAlignedEmulation
+        isFullyAligned
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);

@@ -516,10 +596,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
     bool emulationRequiresPartialStores =
-        !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+        !isFullyAligned || *foldedNumFrontPadElems != 0;
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
-      auto numElements = origElements / numSrcElemsPerDest;
+      auto numElements = origElements / emulatedPerContainerElem;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, containerElemTy),
           op.getValueToStore());
@@ -567,7 +647,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {

     // Build a mask used for rmw.
     auto subWidthStoreMaskType =
-        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());

     auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;

@@ -576,10 +656,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // with the unaligned part so that the rest elements are aligned to width
     // boundary.
     auto frontSubWidthStoreElem =
-        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+        (emulatedPerContainerElem - *foldedNumFrontPadElems) % emulatedPerContainerElem;
     if (frontSubWidthStoreElem > 0) {
-      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
-      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
+      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
        std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                    origElements, true);
        frontSubWidthStoreElem = origElements;
@@ -590,7 +670,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
      auto frontMask = rewriter.create<arith::ConstantOp>(
          loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));

-     currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+     currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
      auto value =
          extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                               frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +694,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // After the previous step, the store address is aligned to the emulated
     // width boundary.
     int64_t fullWidthStoreSize =
-        (origElements - currentSourceIndex) / numSrcElemsPerDest;
-    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+        (origElements - currentSourceIndex) / emulatedPerContainerElem;
+    int64_t numNonFullWidthElements = fullWidthStoreSize * emulatedPerContainerElem;
     if (fullWidthStoreSize > 0) {
       auto fullWidthStorePart = staticallyExtractSubvector(
           rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +704,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto originType = cast<VectorType>(fullWidthStorePart.getType());
       auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
       auto storeType = VectorType::get(
-          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+          {originType.getNumElements() / emulatedPerContainerElem}, memrefElemType);
       auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                         fullWidthStorePart);
       rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +726,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
                                currentSourceIndex, remainingElements, 0);

       // Generate back mask.
-      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
@@ -960,7 +1040,8 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;

     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,7 +1056,7 @@ struct ConvertVectorMaskedLoad final
             getAsOpFoldResult(adaptor.getIndices()));

     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
+        isFullyAligned
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);

@@ -1001,7 +1082,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
           origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -1029,7 +1110,7 @@ struct ConvertVectorMaskedLoad final
      mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                        linearizedInfo.intraDataOffset,
                                        origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
      mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                       *foldedIntraVectorOffset);
     }
@@ -1040,7 +1121,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, result, op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
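To make the EXAMPLE 2 walk-through in the new ConvertVectorStore comment concrete, a companion sketch (again illustrative only; `linearElemIndex`, `frontPadElems`, `firstByte`, `lastByte`, and the printing helper are invented for this sketch, while the other names mirror the diff): it computes which container bytes the unaligned vector<3xi2> store touches and the front/back RMW masks that appear above as %mask_1 and %mask_2.

// Illustrative sketch only (not from the commit): byte/offset and RMW-mask
// arithmetic for EXAMPLE 2 - storing vector<3xi2> into memref<3x3xi2> at
// [%c2, %c0], emulated with i8 containers.
#include <cstdio>
#include <vector>

int main() {
  const int emulatedBits = 2, containerBits = 8;
  const int emulatedPerContainerElem = containerBits / emulatedBits; // 4 x i2 per byte
  const int origElements = 3;            // vector<3xi2>
  const int linearElemIndex = 2 * 3 + 0; // row-major index of [2, 0] = 6

  // Front padding within the first touched byte (the intra-byte offset).
  const int frontPadElems = linearElemIndex % emulatedPerContainerElem; // 2
  const int firstByte = linearElemIndex / emulatedPerContainerElem;     // byte 1
  const int lastByte =
      (linearElemIndex + origElements - 1) / emulatedPerContainerElem;  // byte 2
  std::printf("bytes touched: %d..%d (front padding: %d elements)\n", firstByte,
              lastByte, frontPadElems);

  // Front mask: the last `frontSubWidthStoreElem` lanes of the first byte.
  const int frontSubWidthStoreElem =
      (emulatedPerContainerElem - frontPadElems) % emulatedPerContainerElem; // 2
  std::vector<bool> frontMask(emulatedPerContainerElem, false);
  for (int i = frontPadElems; i < frontPadElems + frontSubWidthStoreElem; ++i)
    frontMask[i] = true; // -> [false, false, true, true], i.e. %mask_1

  // Back mask: the first `remainingElements` lanes of the last byte
  // (no full-byte stores are needed in this example).
  const int remainingElements = origElements - frontSubWidthStoreElem; // 1
  std::vector<bool> backMask(emulatedPerContainerElem, false);
  for (int i = 0; i < remainingElements; ++i)
    backMask[i] = true; // -> [true, false, false, false], i.e. %mask_2

  auto print = [](const char *name, const std::vector<bool> &mask) {
    std::printf("%s =", name);
    for (bool bit : mask)
      std::printf(" %d", bit ? 1 : 0);
    std::printf("\n");
  };
  print("frontMask", frontMask);
  print("backMask", backMask);
  return 0;
}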
