@@ -432,7 +432,86 @@ namespace {
432432// ConvertVectorStore
433433// ===----------------------------------------------------------------------===//
434434
435- // TODO: Document-me
435+ // Emulate vector.store using a multi-byte container type
436+ //
437+ // The container type is obtained through the Op adaptor and would normally be
438+ // generated via `NarrowTypeEmulationConverter`.
439+ //
440+ // EXAMPLE 1
441+ // (aligned store of i4, emulated using i8)
442+ //
443+ // vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
444+ //
445+ // is rewritten as:
446+ //
447+ // %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
448+ // vector.store %src_bitcast, %dest_bitcast[%idx]
449+ // : memref<16xi8>, vector<4xi8>
450+ //
451+ // EXAMPLE 2
452+ // (unaligned store of i2, emulated using i8, non-atomic)
453+ //
454+ //   vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
455+ //
456+ // The i2 store is emulated through 2 x RMW sequences. The destination i2 memref
457+ // is modelled using 3 bytes:
458+ //
459+ // Byte 0 Byte 1 Byte 2
460+ // +----------+----------+----------+
461+ // | oooooooo | ooooNNNN | NNoooooo |
462+ // +----------+----------+----------+
463+ //
464+ // N - (N)ew entries (i.e. to be overwritten by vector.store)
465+ // o - (o)ld entries (to be preserved)
466+ //
467+ // The following 2 RMW sequences will be generated:
468+ //
469+ // %init = arith.constant dense<0> : vector<4xi2>
470+ //
471+ // (RMW sequence for Byte 1)
472+ // (Mask for 4 x i2 elements, i.e. a byte)
473+ // %mask_1 = arith.constant dense<[false, false, true, true]>
474+ // %src_slice_1 = vector.extract_strided_slice %src
475+ // {offsets = [0], sizes = [2], strides = [1]}
476+ // : vector<3xi2> to vector<2xi2>
477+ // %init_with_slice_1 = vector.insert_strided_slice %src_slice_1, %init
478+ // {offsets = [2], strides = [1]}
479+ // : vector<2xi2> into vector<4xi2>
480+ // %dest_byte_1 = vector.load %dest[%c1]
481+ // %dest_byte_1_as_i2 = vector.bitcast %dest_byte_1
482+ // : vector<1xi8> to vector<4xi2>
483+ // %res_byte_1 = arith.select %mask_1, %init_with_slice_1, %dest_byte_1_as_i2
484+ // %res_byte_1_as_i8 = vector.bitcast %res_byte_1
485+ //  vector.store %res_byte_1_as_i8, %dest[%c1]
486+ //
487+ //  (RMW sequence for Byte 2)
488+ // (Mask for 4 x i2 elements, i.e. a byte)
489+ // %mask_2 = arith.constant dense<[true, false, false, false]>
490+ // %src_slice_2 = vector.extract_strided_slice %src
491+ //    {offsets = [2], sizes = [1], strides = [1]}
492+ //    : vector<3xi2> to vector<1xi2>
493+ //  %init_with_slice_2 = vector.insert_strided_slice %src_slice_2, %init
494+ //    {offsets = [0], strides = [1]}
495+ //    : vector<1xi2> into vector<4xi2>
496+ // %dest_byte_2 = vector.load %dest[%c2]
497+ // %dest_byte_2_as_i2 = vector.bitcast %dest_byte_2
498+ // : vector<1xi8> to vector<4xi2>
499+ //  %res_byte_2 = arith.select %mask_2, %init_with_slice_2, %dest_byte_2_as_i2
500+ //  %res_byte_2_as_i8 = vector.bitcast %res_byte_2
501+ //  vector.store %res_byte_2_as_i8, %dest[%c2]
502+ //
503+ // NOTE: Unlike EXAMPLE 1, this case requires index re-calculation.
504+ // NOTE: This example assumes that `disableAtomicRMW` was set.
505+ //
506+ // EXAMPLE 3
507+ // (unaligned store of i2, emulated using i8, atomic)
508+ //
509+ // Similar to EXAMPLE 2, with the addition of
510+ // * `memref.generic_atomic_rmw`,
511+ // to guarantee atomicity. The actual output is skipped for brevity.
512+ //
513+ // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
514+ // `true` to generate non-atomic RMW sequences.
436515struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
437516 using OpConversionPattern::OpConversionPattern;
438517
@@ -464,7 +543,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
464543 op, " impossible to pack emulated elements into container elements "
465544 " (bit-wise misalignment)" );
466545 }
467- int numSrcElemsPerDest = containerBits / emulatedBits;
546+ int emulatedPerContainerElem = containerBits / emulatedBits;
468547
469548 // Adjust the number of elements to store when emulating narrow types.
470549 // Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +559,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
480559 // vector<4xi8>
481560
482561 auto origElements = valueToStore.getType ().getNumElements ();
483- bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0 ;
562+ // Note: per-element alignment was already verified above.
563+ bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
484564
485565 auto stridedMetadata =
486566 rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
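To make the alignment check above concrete, here is a minimal standalone sketch using the i2-in-i8 numbers from EXAMPLE 2 (the literal values are illustrative assumptions):

    #include <cassert>

    int main() {
      int containerBits = 8;  // i8 container element
      int emulatedBits = 2;   // i2 emulated element
      int emulatedPerContainerElem = containerBits / emulatedBits; // 4 x i2 per byte
      int origElements = 3;   // vector<3xi2> being stored
      // 3 % 4 != 0, so a plain bitcast-and-store is impossible and the partial
      // (RMW) path is taken instead.
      bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
      assert(emulatedPerContainerElem == 4 && !isFullyAligned);
      return 0;
    }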
@@ -496,7 +576,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
496576 getAsOpFoldResult (adaptor.getIndices ()));
497577
498578 std::optional<int64_t > foldedNumFrontPadElems =
499- isAlignedEmulation
579+ isFullyAligned
500580 ? 0
501581 : getConstantIntValue (linearizedInfo.intraDataOffset );
502582
@@ -516,10 +596,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
516596 // need unaligned emulation because the store address is aligned and the
517597 // source is a whole byte.
518598 bool emulationRequiresPartialStores =
519- !isAlignedEmulation || *foldedNumFrontPadElems != 0 ;
599+ !isFullyAligned || *foldedNumFrontPadElems != 0 ;
520600 if (!emulationRequiresPartialStores) {
521601 // Basic case: storing full bytes.
522- auto numElements = origElements / numSrcElemsPerDest ;
602+ auto numElements = origElements / emulatedPerContainerElem ;
523603 auto bitCast = rewriter.create <vector::BitCastOp>(
524604 loc, VectorType::get (numElements, containerElemTy),
525605 op.getValueToStore ());
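The aligned fast path above corresponds to EXAMPLE 1; a quick sketch of its arithmetic (values taken from that example):

    #include <cassert>

    int main() {
      int emulatedPerContainerElem = 8 / 4; // i4 emulated in i8 -> 2 per byte
      int origElements = 8;                 // vector<8xi4>
      assert(origElements % emulatedPerContainerElem == 0); // fully aligned
      // vector<8xi4> is bitcast to vector<4xi8> and written with a single
      // vector.store, exactly as in EXAMPLE 1.
      int numElements = origElements / emulatedPerContainerElem;
      assert(numElements == 4);
      return 0;
    }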
@@ -567,7 +647,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
567647
568648 // Build a mask used for rmw.
569649 auto subWidthStoreMaskType =
570- VectorType::get ({numSrcElemsPerDest }, rewriter.getI1Type ());
650+ VectorType::get ({emulatedPerContainerElem }, rewriter.getI1Type ());
571651
572652 auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
573653
@@ -576,10 +656,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
576656 // with the unaligned part so that the rest elements are aligned to width
577657 // boundary.
578658 auto frontSubWidthStoreElem =
579- (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest ;
659+ (emulatedPerContainerElem - *foldedNumFrontPadElems) % emulatedPerContainerElem ;
580660 if (frontSubWidthStoreElem > 0 ) {
581- SmallVector<bool > frontMaskValues (numSrcElemsPerDest , false );
582- if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest ) {
661+ SmallVector<bool > frontMaskValues (emulatedPerContainerElem , false );
662+ if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem ) {
583663 std::fill_n (frontMaskValues.begin () + *foldedNumFrontPadElems,
584664 origElements, true );
585665 frontSubWidthStoreElem = origElements;
@@ -590,7 +670,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
590670 auto frontMask = rewriter.create <arith::ConstantOp>(
591671 loc, DenseElementsAttr::get (subWidthStoreMaskType, frontMaskValues));
592672
593- currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
673+ currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
594674 auto value =
595675 extractSliceIntoByte (rewriter, loc, valueToStore, 0 ,
596676 frontSubWidthStoreElem, *foldedNumFrontPadElems);
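Plugging the EXAMPLE 2 numbers into the front partial store: the destination indices [%c2, %c0] linearize to i2 element 6, i.e. two elements into byte 1, and the resulting front mask is exactly %mask_1 from the comment at the top. The standalone snippet below mirrors that arithmetic; the concrete values are assumptions taken from the example.

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main() {
      int emulatedPerContainerElem = 4; // 4 x i2 per byte
      int foldedNumFrontPadElems = 2;   // store starts 2 i2 elements into byte 1
      int frontSubWidthStoreElem =
          (emulatedPerContainerElem - foldedNumFrontPadElems) %
          emulatedPerContainerElem;     // 2 elements land in byte 1
      std::vector<bool> frontMaskValues(emulatedPerContainerElem, false);
      std::fill_n(frontMaskValues.begin() + foldedNumFrontPadElems,
                  frontSubWidthStoreElem, true);
      // frontMaskValues == {false, false, true, true}, i.e. %mask_1.
      assert(frontSubWidthStoreElem == 2 && frontMaskValues[2] && frontMaskValues[3]);
      // Reading of %src resumes at element 2 for the remaining stores.
      int currentSourceIndex = emulatedPerContainerElem - foldedNumFrontPadElems;
      assert(currentSourceIndex == 2);
      return 0;
    }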
@@ -614,8 +694,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
614694 // After the previous step, the store address is aligned to the emulated
615695 // width boundary.
616696 int64_t fullWidthStoreSize =
617- (origElements - currentSourceIndex) / numSrcElemsPerDest ;
618- int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest ;
697+ (origElements - currentSourceIndex) / emulatedPerContainerElem ;
698+ int64_t numNonFullWidthElements = fullWidthStoreSize * emulatedPerContainerElem ;
619699 if (fullWidthStoreSize > 0 ) {
620700 auto fullWidthStorePart = staticallyExtractSubvector (
621701 rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +704,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
624704 auto originType = cast<VectorType>(fullWidthStorePart.getType ());
625705 auto memrefElemType = getElementTypeOrSelf (memrefBase.getType ());
626706 auto storeType = VectorType::get (
627- {originType.getNumElements () / numSrcElemsPerDest }, memrefElemType);
707+ {originType.getNumElements () / emulatedPerContainerElem }, memrefElemType);
628708 auto bitCast = rewriter.create <vector::BitCastOp>(loc, storeType,
629709 fullWidthStorePart);
630710 rewriter.create <vector::StoreOp>(loc, bitCast.getResult (), memrefBase,
@@ -646,7 +726,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
646726 currentSourceIndex, remainingElements, 0 );
647727
648728 // Generate back mask.
649- auto maskValues = SmallVector<bool >(numSrcElemsPerDest , 0 );
729+ auto maskValues = SmallVector<bool >(emulatedPerContainerElem , 0 );
650730 std::fill_n (maskValues.begin (), remainingElements, 1 );
651731 auto backMask = rewriter.create <arith::ConstantOp>(
652732 loc, DenseElementsAttr::get (subWidthStoreMaskType, maskValues));
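With the same EXAMPLE 2 numbers, the full-width stage above is a no-op and the back mask reproduces %mask_2 from the comment at the top. The sketch below also notes an assumed larger source (vector<11xi2>) for which the full-width stage would actually fire:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main() {
      int emulatedPerContainerElem = 4; // 4 x i2 per byte
      int origElements = 3;             // vector<3xi2>
      int currentSourceIndex = 2;       // elements already written by the front store

      // Full-width stage: (3 - 2) / 4 == 0, so no whole byte is stored here.
      // (An assumed vector<11xi2> with the same front offset would instead give
      //  (11 - 2) / 4 == 2, i.e. two whole bytes stored via a single bitcast.)
      int fullWidthStoreSize =
          (origElements - currentSourceIndex) / emulatedPerContainerElem;
      int numNonFullWidthElements = fullWidthStoreSize * emulatedPerContainerElem;
      assert(fullWidthStoreSize == 0 && numNonFullWidthElements == 0);

      // Back stage: one element of %src remains; its mask matches %mask_2.
      int remainingElements =
          origElements - currentSourceIndex - numNonFullWidthElements; // 1
      std::vector<bool> maskValues(emulatedPerContainerElem, false);
      std::fill_n(maskValues.begin(), remainingElements, true);
      assert(remainingElements == 1 && maskValues[0] && !maskValues[1]);
      return 0;
    }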
@@ -960,7 +1040,8 @@ struct ConvertVectorMaskedLoad final
9601040 // subvector at the proper offset after bit-casting.
9611041 auto origType = op.getVectorType ();
9621042 auto origElements = origType.getNumElements ();
963- bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0 ;
1043+ // Note: per-element alignment was already verified above.
1044+ bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
9641045
9651046 auto stridedMetadata =
9661047 rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
@@ -975,7 +1056,7 @@ struct ConvertVectorMaskedLoad final
9751056 getAsOpFoldResult (adaptor.getIndices ()));
9761057
9771058 std::optional<int64_t > foldedIntraVectorOffset =
978- isAlignedEmulation
1059+ isFullyAligned
9791060 ? 0
9801061 : getConstantIntValue (linearizedInfo.intraDataOffset );
9811062
@@ -1001,7 +1082,7 @@ struct ConvertVectorMaskedLoad final
10011082 passthru = dynamicallyInsertSubVector (
10021083 rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset ,
10031084 origElements);
1004- } else if (!isAlignedEmulation ) {
1085+ } else if (!isFullyAligned ) {
10051086 passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
10061087 *foldedIntraVectorOffset);
10071088 }
@@ -1029,7 +1110,7 @@ struct ConvertVectorMaskedLoad final
10291110 mask = dynamicallyInsertSubVector (rewriter, loc, mask, emptyMask,
10301111 linearizedInfo.intraDataOffset ,
10311112 origElements);
1032- } else if (!isAlignedEmulation ) {
1113+ } else if (!isFullyAligned ) {
10331114 mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
10341115 *foldedIntraVectorOffset);
10351116 }
@@ -1040,7 +1121,7 @@ struct ConvertVectorMaskedLoad final
10401121 result = dynamicallyExtractSubVector (
10411122 rewriter, loc, result, op.getPassThru (),
10421123 linearizedInfo.intraDataOffset , origElements);
1043- } else if (!isAlignedEmulation ) {
1124+ } else if (!isFullyAligned ) {
10441125 result = staticallyExtractSubvector (
10451126 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
10461127 }
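The masked-load changes mirror the store side: when the load is not fully aligned, the passthru and mask are first widened by inserting them at the intra-byte offset, and the loaded result is narrowed back by extracting at that same offset. A small arithmetic sketch with assumed values in the spirit of EXAMPLE 2:

    #include <cassert>

    int main() {
      int emulatedPerContainerElem = 4; // 4 x i2 per i8 container element
      int origElements = 3;             // vector<3xi2> requested by the load
      int intraOffset = 2;              // assumed static intra-byte offset
      // The request straddles a byte boundary: ceil((2 + 3) / 4) == 2 container
      // bytes are touched, hence the widen-then-extract bookkeeping.
      int containerElemsTouched =
          (intraOffset + origElements + emulatedPerContainerElem - 1) /
          emulatedPerContainerElem;
      assert(containerElemsTouched == 2);
      // Passthru/mask are inserted at offset 2; the result is extracted back at
      // offset 2 with the original length of 3 elements.
      return 0;
    }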