@@ -432,7 +432,45 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
-// TODO: Document-me
+// Emulate `vector.store` using a multi-byte container type.
+//
+// The container type is obtained through the Op adaptor and would normally be
+// generated via `NarrowTypeEmulationConverter`.
+//
+// EXAMPLE 1
+// (aligned store of i4, emulated using i8 as the container type)
+//
+//    vector.store %src, %dest[%idx_1, %idx_2] : memref<4x8xi4>, vector<8xi4>
+//
+// is rewritten as:
+//
+//    %src_bitcast = vector.bitcast %src : vector<8xi4> to vector<4xi8>
+//    vector.store %src_bitcast, %dest_bitcast[%idx]
+//      : memref<16xi8>, vector<4xi8>
+//
+// EXAMPLE 2
+// (unaligned store of i2, emulated using i8 as the container type)
+//
+//    vector.store %src, %dest[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
+//
+// The i2 store is emulated through 2 x RMW sequences. The destination i2
+// memref is modelled using 3 bytes:
+//
+//      Byte 0     Byte 1     Byte 2
+//    +----------+----------+----------+
+//    | oooooooo | ooooNNNN | NNoooooo |
+//    +----------+----------+----------+
+//
+// N - (N)ew entries (i.e. to be overwritten by vector.store)
+// o - (o)ld entries (to be preserved)
+//
+// For the generated output in the non-atomic case, see:
+//  * `@vector_store_i2_const_index_two_partial_stores`
+// in:
+//  * "vector-emulate-narrow-type-unaligned-non-atomic.mlir".
+//
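+// As a rough illustration of the non-atomic case (the SSA names and the
+// linearized memref below are made up; see the test above for the exact
+// output), each of the two partial stores is a read-modify-write sequence
+// along these lines:
+//
+//    %old = vector.load %dest_bytes[%byte_idx] : memref<3xi8>, vector<1xi8>
+//    %old_i2 = vector.bitcast %old : vector<1xi8> to vector<4xi2>
+//    %new_i2 = arith.select %mask, %src_i2, %old_i2
+//        : vector<4xi1>, vector<4xi2>
+//    %new = vector.bitcast %new_i2 : vector<4xi2> to vector<1xi8>
+//    vector.store %new, %dest_bytes[%byte_idx] : memref<3xi8>, vector<1xi8>
+//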
+// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
+// `true` to generate non-atomic RMW sequences.
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -464,7 +502,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           op, "impossible to pack emulated elements into container elements "
               "(bit-wise misalignment)");
     }
-    int numSrcElemsPerDest = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -480,7 +518,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     //   vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -496,9 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return rewriter.notifyMatchFailure(
@@ -516,10 +554,10 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // need unaligned emulation because the store address is aligned and the
     // source is a whole byte.
     bool emulationRequiresPartialStores =
-        !isAlignedEmulation || *foldedNumFrontPadElems != 0;
+        !isFullyAligned || *foldedNumFrontPadElems != 0;
     if (!emulationRequiresPartialStores) {
       // Basic case: storing full bytes.
-      auto numElements = origElements / numSrcElemsPerDest;
+      auto numElements = origElements / emulatedPerContainerElem;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, containerElemTy),
           op.getValueToStore());
@@ -567,7 +605,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     // Build a mask used for RMW.
     auto subWidthStoreMaskType =
-        VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
+        VectorType::get({emulatedPerContainerElem}, rewriter.getI1Type());
 
     auto storeFunc = disableAtomicRMW ? nonAtomicRMW : atomicRMW;
 
@@ -576,10 +614,11 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // with the unaligned part so that the rest elements are aligned to width
     // boundary.
     auto frontSubWidthStoreElem =
-        (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
+        (emulatedPerContainerElem - *foldedNumFrontPadElems) %
+        emulatedPerContainerElem;
     if (frontSubWidthStoreElem > 0) {
-      SmallVector<bool> frontMaskValues(numSrcElemsPerDest, false);
-      if (*foldedNumFrontPadElems + origElements < numSrcElemsPerDest) {
+      SmallVector<bool> frontMaskValues(emulatedPerContainerElem, false);
+      if (*foldedNumFrontPadElems + origElements < emulatedPerContainerElem) {
         std::fill_n(frontMaskValues.begin() + *foldedNumFrontPadElems,
                     origElements, true);
         frontSubWidthStoreElem = origElements;
@@ -590,7 +629,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto frontMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, frontMaskValues));
 
-      currentSourceIndex = numSrcElemsPerDest - (*foldedNumFrontPadElems);
+      currentSourceIndex = emulatedPerContainerElem - (*foldedNumFrontPadElems);
       auto value =
           extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                                frontSubWidthStoreElem, *foldedNumFrontPadElems);
@@ -614,8 +653,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // After the previous step, the store address is aligned to the emulated
     // width boundary.
     int64_t fullWidthStoreSize =
-        (origElements - currentSourceIndex) / numSrcElemsPerDest;
-    int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
+        (origElements - currentSourceIndex) / emulatedPerContainerElem;
+    int64_t numNonFullWidthElements =
+        fullWidthStoreSize * emulatedPerContainerElem;
     if (fullWidthStoreSize > 0) {
       auto fullWidthStorePart = staticallyExtractSubvector(
           rewriter, loc, valueToStore, currentSourceIndex,
@@ -624,7 +664,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       auto originType = cast<VectorType>(fullWidthStorePart.getType());
       auto memrefElemType = getElementTypeOrSelf(memrefBase.getType());
       auto storeType = VectorType::get(
-          {originType.getNumElements() / numSrcElemsPerDest}, memrefElemType);
+          {originType.getNumElements() / emulatedPerContainerElem},
+          memrefElemType);
       auto bitCast = rewriter.create<vector::BitCastOp>(loc, storeType,
                                                         fullWidthStorePart);
       rewriter.create<vector::StoreOp>(loc, bitCast.getResult(), memrefBase,
@@ -646,7 +687,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
                                currentSourceIndex, remainingElements, 0);
 
       // Generate back mask.
-      auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
+      auto maskValues = SmallVector<bool>(emulatedPerContainerElem, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
           loc, DenseElementsAttr::get(subWidthStoreMaskType, maskValues));
@@ -960,7 +1001,8 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -975,9 +1017,8 @@ struct ConvertVectorMaskedLoad final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1001,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset,
           origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -1029,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask,
                                         linearizedInfo.intraDataOffset,
                                         origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -1040,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, result, op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }