@@ -400,6 +400,9 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+///
+///
+
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -443,7 +446,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     // vector<4xi8>
 
     auto origElements = valueToStore.getType().getNumElements();
-    bool isUnalignedEmulation = origElements % numSrcElemsPerDest != 0;
+    bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -459,9 +462,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedNumFrontPadElems =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     if (!foldedNumFrontPadElems) {
       return failure("subbyte store emulation: dynamic front padding size is "
@@ -472,13 +475,13 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 
     // Shortcut: conditions when subbyte emulated store at the front is not
     // needed:
-    // 1. The source vector size is multiple of byte size
-    // 2. The address of the store is aligned to the emulated width boundary
+    // 1. The source vector size (in bits) is a multiple of byte size.
+    // 2. The address of the store is aligned to the emulated width boundary.
     //
     // For example, storing a vector<4xi2> to <13xi2> at offset 4 does not need
     // unaligned emulation because the store address is aligned and the source
     // is a whole byte.
-    if (!isUnalignedEmulation && *foldedNumFrontPadElems == 0) {
+    if (isAlignedEmulation && *foldedNumFrontPadElems == 0) {
       auto numElements = origElements / numSrcElemsPerDest;
       auto bitCast = rewriter.create<vector::BitCastOp>(
           loc, VectorType::get(numElements, newElementType),
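
The shortcut above reduces to two integer checks on the element counts. A minimal standalone sketch of the same arithmetic, assuming a simple linear element offset (plain integers instead of MLIR values; `canUseFullWidthStore` and its parameters are illustrative names, not from this file):

#include <cassert>
#include <cstdint>

// Returns true when the sub-byte store can be lowered to a single
// byte-aligned full-width store, i.e. no read-modify-write is needed.
bool canUseFullWidthStore(int64_t origElements, int64_t srcBits,
                          int64_t destBits, int64_t offsetInSrcElems) {
  int64_t numSrcElemsPerDest = destBits / srcBits; // e.g. 8 / 2 == 4
  bool isAlignedEmulation = origElements % numSrcElemsPerDest == 0;
  int64_t numFrontPadElems = offsetInSrcElems % numSrcElemsPerDest;
  return isAlignedEmulation && numFrontPadElems == 0;
}

int main() {
  // vector<4xi2> stored into memref<13xi2> at offset 4: whole bytes and a
  // byte-aligned address -> plain bitcast + store is enough.
  assert(canUseFullWidthStore(/*origElements=*/4, /*srcBits=*/2,
                              /*destBits=*/8, /*offsetInSrcElems=*/4));
  // vector<7xi2> stored at offset 2 -> needs the partial-store path below.
  assert(!canUseFullWidthStore(7, 2, 8, 2));
  return 0;
}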
@@ -489,17 +492,50 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       return success();
     }
 
-    // The index into the target memref we are storing to
+    // Next, handle the case when sub-byte read-modify-write
+    // sequences are needed to emulate a vector store.
+    // Here is an example:
+    //
+    // Vector to store: vector<7xi2>
+    // Value to store: 11 11 11 11 11 11 11 (all ones)
+    //
+    // Destination: memref<12xi2>
+    // Store offset: 2 (i.e. 4 bits into the 1st emulated byte).
+    //
+    // MLIR: vector.store %val, %dest[%c2] : memref<12xi2>, vector<7xi2>
+    //
+    // Destination memref before:
+    //
+    //    Byte 0     Byte 1     Byte 2
+    // +----------+----------+----------+
+    // | 00000000 | 00000000 | 00000000 |
+    // +----------+----------+----------+
+    //
+    // Destination memref after:
+    //
+    //    Byte 0     Byte 1     Byte 2
+    // +----------+----------+----------+
+    // | 00001111 | 11111111 | 11000000 |
+    // +----------+----------+----------+
+    //
+    // Note that stores to Byte 1 are "full-width" and hence don't require RMW
+    // (no need for atomicity). Stores to Byte 0 and Byte 2 are "partial",
+    // hence requiring RMW access (atomicity is required).
+
+    // The index into the target memref we are storing to.
     Value currentDestIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
+    // The index into the source vector we are currently processing.
+    auto currentSourceIndex = 0;
+
+    // Build a mask used for RMW.
     auto subWidthStoreMaskType =
         VectorType::get({numSrcElemsPerDest}, rewriter.getI1Type());
-    // The index into the source vector we are currently processing
-    auto currentSourceIndex = 0;
 
-    // 1. Partial width store for the first byte, when the store address is not
-    // aligned to emulated width boundary, deal with the unaligned part so that
-    // the rest elements are aligned to width boundary.
+    // 1. Partial width store for the leading byte.
+    // When the store address is not aligned to the emulated width boundary,
+    // deal with the unaligned part so that the remaining elements are aligned
+    // to the width boundary.
     auto frontSubWidthStoreElem =
         (numSrcElemsPerDest - *foldedNumFrontPadElems) % numSrcElemsPerDest;
     if (frontSubWidthStoreElem > 0) {
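
To make the first phase concrete for the vector<7xi2> example above, here is a standalone sketch of the leading-byte arithmetic, assuming the mask layout implied by the "00001111" picture for Byte 0 (ordinary integers; names only loosely mirror the code):

#include <cstdint>
#include <cstdio>

int main() {
  // Setup from the vector<7xi2> example above.
  const int64_t numSrcElemsPerDest = 4; // four i2 elements per emulated i8
  const int64_t numFrontPadElems = 2;   // store starts 2 elements into Byte 0

  // Elements handled by the leading partial (RMW) store.
  int64_t frontSubWidthStoreElem =
      (numSrcElemsPerDest - numFrontPadElems) % numSrcElemsPerDest; // == 2

  // Per-element mask over Byte 0: preserve the 2 padding slots, write the
  // rest, matching the "00001111" picture for Byte 0 above.
  std::printf("front elements: %lld, mask:", (long long)frontSubWidthStoreElem);
  for (int64_t i = 0; i < numSrcElemsPerDest; ++i)
    std::printf(" %d", i >= numFrontPadElems ? 1 : 0); // prints: 0 0 1 1
  std::printf("\n");
  return 0;
}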
@@ -535,8 +571,9 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     currentDestIndex = rewriter.create<arith::AddIOp>(
         loc, rewriter.getIndexType(), currentDestIndex, constantOne);
 
-    // 2. Full width store. After the previous step, the store address is
-    // aligned to the emulated width boundary.
+    // 2. Full width store for the inner output bytes.
+    // After the previous step, the store address is aligned to the emulated
+    // width boundary.
     int64_t fullWidthStoreSize =
         (origElements - currentSourceIndex) / numSrcElemsPerDest;
     int64_t numNonFullWidthElements = fullWidthStoreSize * numSrcElemsPerDest;
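
Continuing the same worked example, a sketch of the full-width phase sizing; `numFullWidthElements` is my name for the product that the code above calls `numNonFullWidthElements`:

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t numSrcElemsPerDest = 4; // i2 elements per emulated i8
  const int64_t origElements = 7;       // vector<7xi2>
  int64_t currentSourceIndex = 2;       // consumed by the leading partial store

  // Whole destination bytes that can be written with plain, non-RMW stores.
  int64_t fullWidthStoreSize =
      (origElements - currentSourceIndex) / numSrcElemsPerDest; // == 1
  int64_t numFullWidthElements =
      fullWidthStoreSize * numSrcElemsPerDest;                  // == 4

  std::printf("full-width bytes: %lld covering %lld elements\n",
              (long long)fullWidthStoreSize, (long long)numFullWidthElements);
  return 0;
}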
@@ -560,15 +597,16 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
           rewriter.create<arith::ConstantIndexOp>(loc, fullWidthStoreSize));
     }
 
-    // 3. Deal with trailing elements that are aligned to the emulated width,
-    // but their length is smaller than the emulated width.
+    // 3. Partial width store for the trailing output byte.
+    // It is needed when the residual length is smaller than the emulated
+    // width, which is not covered in step 2 above.
     auto remainingElements = origElements - currentSourceIndex;
     if (remainingElements != 0) {
       auto subWidthStorePart =
           extractSliceIntoByte(rewriter, loc, cast<VectorValue>(valueToStore),
                                currentSourceIndex, remainingElements, 0);
 
-      // Generate back mask
+      // Generate back mask.
       auto maskValues = SmallVector<bool>(numSrcElemsPerDest, 0);
       std::fill_n(maskValues.begin(), remainingElements, 1);
       auto backMask = rewriter.create<arith::ConstantOp>(
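
And the trailing phase of the same example as a standalone sketch; the mask construction mirrors the `maskValues`/`std::fill_n` lines above, with the numbers filled in for vector<7xi2> at offset 2:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t numSrcElemsPerDest = 4; // i2 elements per emulated i8
  const int64_t origElements = 7;       // vector<7xi2>
  int64_t currentSourceIndex = 2 + 4;   // leading partial + one full byte

  // Elements left for the trailing partial (RMW) store.
  int64_t remainingElements = origElements - currentSourceIndex; // == 1

  // Back mask: write the first `remainingElements` slots of the last byte and
  // preserve the rest, matching the "11000000" picture for Byte 2 above.
  std::vector<bool> maskValues(numSrcElemsPerDest, false);
  std::fill_n(maskValues.begin(), remainingElements, true);
  std::printf("remaining: %lld, back mask:", (long long)remainingElements);
  for (bool b : maskValues)
    std::printf(" %d", b ? 1 : 0); // prints: 1 0 0 0
  std::printf("\n");
  return 0;
}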
@@ -751,7 +789,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isAlignedEmulation = origElements % scale == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -767,9 +805,9 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
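
The "load enough elements to cover the original elements" comment amounts to a ceiling division over the worst-case intra-byte offset. A hedged standalone sketch of that sizing; the helper name and signature are mine, not from this file:

#include <cstdint>
#include <cstdio>

// Number of emulated (byte-sized) elements that must be loaded so that
// `origElements` sub-byte elements starting `intraDataOffset` elements into
// the first byte are fully covered.
int64_t coveringLoadSize(int64_t origElements, int64_t scale,
                         int64_t intraDataOffset) {
  return (intraDataOffset + origElements + scale - 1) / scale; // ceil division
}

int main() {
  // vector<7xi2> (scale == 4) starting 2 elements into a byte touches
  // ceil((2 + 7) / 4) == 3 bytes, as in the store example further up.
  std::printf("%lld\n", (long long)coveringLoadSize(7, 4, 2)); // prints 3
  return 0;
}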
@@ -785,7 +823,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
           rewriter, loc, cast<VectorValue>(result), resultVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -867,7 +905,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isAlignedEmulation = origElements % scale == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -882,9 +920,9 @@ struct ConvertVectorMaskedLoad final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     FailureOr<Operation *> newMask = getCompressedMaskOp(
@@ -905,7 +943,7 @@ struct ConvertVectorMaskedLoad final
       passthru = dynamicallyInsertSubVector(
           rewriter, loc, cast<VectorValue>(passthru), emptyVector,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector,
                                            *foldedIntraVectorOffset);
     }
@@ -933,7 +971,7 @@ struct ConvertVectorMaskedLoad final
       mask = dynamicallyInsertSubVector(
           rewriter, loc, cast<VectorValue>(mask), emptyMask,
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask,
                                        *foldedIntraVectorOffset);
     }
@@ -944,7 +982,7 @@ struct ConvertVectorMaskedLoad final
       result = dynamicallyExtractSubVector(
           rewriter, loc, cast<VectorValue>(result), op.getPassThru(),
           linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -986,7 +1024,7 @@ struct ConvertVectorTransferRead final
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isAlignedEmulation = origElements % scale == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -1005,9 +1043,9 @@ struct ConvertVectorTransferRead final
             getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isAlignedEmulation
+            ? 0
+            : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
     auto numElements =
@@ -1028,7 +1066,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isAlignedEmulation) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }