@@ -415,18 +415,21 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
415415 " only 1-D vectors are supported ATM" );
416416
417417 auto loc = op.getLoc ();
418+
418419 auto valueToStore = cast<VectorValue>(op.getValueToStore ());
419- auto oldElementType = valueToStore.getType ().getElementType ();
420- auto newElementType =
420+ auto containerElemTy =
421421 cast<MemRefType>(adaptor.getBase ().getType ()).getElementType ();
422- int srcBits = oldElementType.getIntOrFloatBitWidth ();
423- int dstBits = newElementType.getIntOrFloatBitWidth ();
422+ Type emulatedElemTy = op.getValueToStore ().getType ().getElementType ();
423+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth ();
424+ int containerBits = containerElemTy.getIntOrFloatBitWidth ();
424425
425- if (dstBits % srcBits != 0 ) {
426+ // Check per-element alignment.
427+ if (containerBits % emulatedBits != 0 ) {
426428 return rewriter.notifyMatchFailure (
427- op, " only dstBits % srcBits == 0 supported" );
429+ op, " impossible to pack emulated elements into container elements "
430+ " (bit-wise misalignment)" );
428431 }
429- int numSrcElemsPerDest = dstBits / srcBits ;
432+ int numSrcElemsPerDest = containerBits / emulatedBits ;
430433
431434 // Adjust the number of elements to store when emulating narrow types.
432435 // Here only the 1-D vector store is considered, and the N-D memref types
@@ -451,7 +454,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
451454 memref::LinearizedMemRefInfo linearizedInfo;
452455 std::tie (linearizedInfo, linearizedIndices) =
453456 memref::getLinearizedMemRefOffsetAndSize (
454- rewriter, loc, srcBits, dstBits ,
457+ rewriter, loc, emulatedBits, containerBits ,
455458 stridedMetadata.getConstifiedMixedOffset (),
456459 stridedMetadata.getConstifiedMixedSizes (),
457460 stridedMetadata.getConstifiedMixedStrides (),
@@ -483,7 +486,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
483486 // Basic case: storing full bytes.
484487 auto numElements = origElements / numSrcElemsPerDest;
485488 auto bitCast = rewriter.create <vector::BitCastOp>(
486- loc, VectorType::get (numElements, newElementType ),
489+ loc, VectorType::get (numElements, containerElemTy ),
487490 op.getValueToStore ());
488491 rewriter.replaceOpWithNewOp <vector::StoreOp>(
489492 op, bitCast.getResult (), memrefBase,
@@ -638,18 +641,20 @@ struct ConvertVectorMaskedStore final
638641 " only 1-D vectors are supported ATM" );
639642
640643 auto loc = op.getLoc ();
641- auto convertedType = cast<MemRefType>(adaptor. getBase (). getType ());
642- Type oldElementType = op. getValueToStore ().getType ().getElementType ();
643- Type newElementType = convertedType .getElementType ();
644- int srcBits = oldElementType .getIntOrFloatBitWidth ();
645- int dstBits = newElementType .getIntOrFloatBitWidth ();
644+ auto containerElemTy =
645+ cast<MemRefType>(adaptor. getBase ().getType () ).getElementType ();
646+ Type emulatedElemTy = op. getValueToStore (). getType () .getElementType ();
647+ int emulatedBits = emulatedElemTy .getIntOrFloatBitWidth ();
648+ int containerBits = containerElemTy .getIntOrFloatBitWidth ();
646649
647- if (dstBits % srcBits != 0 ) {
650+ // Check per-element alignment.
651+ if (containerBits % emulatedBits != 0 ) {
648652 return rewriter.notifyMatchFailure (
649- op, " only dstBits % srcBits == 0 supported" );
653+ op, " impossible to pack emulated elements into container elements "
654+ " (bit-wise misalignment)" );
650655 }
651656
652- int scale = dstBits / srcBits ;
657+ int scale = containerBits / emulatedBits ;
653658 int origElements = op.getValueToStore ().getType ().getNumElements ();
654659 if (origElements % scale != 0 )
655660 return failure ();
@@ -660,7 +665,7 @@ struct ConvertVectorMaskedStore final
660665 memref::LinearizedMemRefInfo linearizedInfo;
661666 std::tie (linearizedInfo, linearizedIndicesOfr) =
662667 memref::getLinearizedMemRefOffsetAndSize (
663- rewriter, loc, srcBits, dstBits ,
668+ rewriter, loc, emulatedBits, containerBits ,
664669 stridedMetadata.getConstifiedMixedOffset (),
665670 stridedMetadata.getConstifiedMixedSizes (),
666671 stridedMetadata.getConstifiedMixedStrides (),
@@ -706,15 +711,15 @@ struct ConvertVectorMaskedStore final
706711 return failure ();
707712
708713 auto numElements = (origElements + scale - 1 ) / scale;
709- auto newType = VectorType::get (numElements, newElementType );
714+ auto newType = VectorType::get (numElements, containerElemTy );
710715 auto passThru = rewriter.create <arith::ConstantOp>(
711716 loc, newType, rewriter.getZeroAttr (newType));
712717
713718 auto newLoad = rewriter.create <vector::MaskedLoadOp>(
714719 loc, newType, adaptor.getBase (), linearizedIndices,
715720 newMask.value ()->getResult (0 ), passThru);
716721
717- auto newBitCastType = VectorType::get (numElements * scale, oldElementType );
722+ auto newBitCastType = VectorType::get (numElements * scale, emulatedElemTy );
718723 Value valueToStore =
719724 rewriter.create <vector::BitCastOp>(loc, newBitCastType, newLoad);
720725 valueToStore = rewriter.create <arith::SelectOp>(
@@ -746,17 +751,19 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
746751 " only 1-D vectors are supported ATM" );
747752
748753 auto loc = op.getLoc ();
749- auto convertedType = cast<MemRefType>(adaptor. getBase (). getType ());
750- Type oldElementType = op. getType ().getElementType ();
751- Type newElementType = convertedType .getElementType ();
752- int srcBits = oldElementType .getIntOrFloatBitWidth ();
753- int dstBits = newElementType .getIntOrFloatBitWidth ();
754+ auto containerElemTy =
755+ cast<MemRefType>(adaptor. getBase (). getType () ).getElementType ();
756+ Type emulatedElemTy = op. getType () .getElementType ();
757+ int emulatedBits = emulatedElemTy .getIntOrFloatBitWidth ();
758+ int containerBits = containerElemTy .getIntOrFloatBitWidth ();
754759
755- if (dstBits % srcBits != 0 ) {
760+ // Check per-element alignment.
761+ if (containerBits % emulatedBits != 0 ) {
756762 return rewriter.notifyMatchFailure (
757- op, " only dstBits % srcBits == 0 supported" );
763+ op, " impossible to pack emulated elements into container elements "
764+ " (bit-wise misalignment)" );
758765 }
759- int scale = dstBits / srcBits ;
766+ int scale = containerBits / emulatedBits ;
760767
761768 // Adjust the number of elements to load when emulating narrow types,
762769 // and then cast back to the original type with vector.bitcast op.
@@ -797,7 +804,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
797804 memref::LinearizedMemRefInfo linearizedInfo;
798805 std::tie (linearizedInfo, linearizedIndices) =
799806 memref::getLinearizedMemRefOffsetAndSize (
800- rewriter, loc, srcBits, dstBits ,
807+ rewriter, loc, emulatedBits, containerBits ,
801808 stridedMetadata.getConstifiedMixedOffset (),
802809 stridedMetadata.getConstifiedMixedSizes (),
803810 stridedMetadata.getConstifiedMixedStrides (),
@@ -814,7 +821,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
814821 llvm::divideCeil (maxintraDataOffset + origElements, scale);
815822 Value result =
816823 emulatedVectorLoad (rewriter, loc, adaptor.getBase (), linearizedIndices,
817- numElements, oldElementType, newElementType );
824+ numElements, emulatedElemTy, containerElemTy );
818825
819826 if (!foldedIntraVectorOffset) {
820827 auto resultVector = rewriter.create <arith::ConstantOp>(
@@ -848,17 +855,20 @@ struct ConvertVectorMaskedLoad final
848855 " only 1-D vectors are supported ATM" );
849856
850857 auto loc = op.getLoc ();
851- auto convertedType = cast<MemRefType>(adaptor.getBase ().getType ());
852- Type oldElementType = op.getType ().getElementType ();
853- Type newElementType = convertedType.getElementType ();
854- int srcBits = oldElementType.getIntOrFloatBitWidth ();
855- int dstBits = newElementType.getIntOrFloatBitWidth ();
856858
857- if (dstBits % srcBits != 0 ) {
859+ auto containerElemTy =
860+ cast<MemRefType>(adaptor.getBase ().getType ()).getElementType ();
861+ Type emulatedElemTy = op.getType ().getElementType ();
862+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth ();
863+ int containerBits = containerElemTy.getIntOrFloatBitWidth ();
864+
865+ // Check per-element alignment.
866+ if (containerBits % emulatedBits != 0 ) {
858867 return rewriter.notifyMatchFailure (
859- op, " only dstBits % srcBits == 0 supported" );
868+ op, " impossible to pack emulated elements into container elements "
869+ " (bit-wise misalignment)" );
860870 }
861- int scale = dstBits / srcBits ;
871+ int scale = containerBits / emulatedBits ;
862872
863873 // Adjust the number of elements to load when emulating narrow types,
864874 // and then cast back to the original type with vector.bitcast op.
@@ -912,7 +922,7 @@ struct ConvertVectorMaskedLoad final
912922 memref::LinearizedMemRefInfo linearizedInfo;
913923 std::tie (linearizedInfo, linearizedIndices) =
914924 memref::getLinearizedMemRefOffsetAndSize (
915- rewriter, loc, srcBits, dstBits ,
925+ rewriter, loc, emulatedBits, containerBits ,
916926 stridedMetadata.getConstifiedMixedOffset (),
917927 stridedMetadata.getConstifiedMixedSizes (),
918928 stridedMetadata.getConstifiedMixedStrides (),
@@ -933,8 +943,8 @@ struct ConvertVectorMaskedLoad final
933943
934944 auto numElements =
935945 llvm::divideCeil (maxIntraDataOffset + origElements, scale);
936- auto loadType = VectorType::get (numElements, newElementType );
937- auto newBitcastType = VectorType::get (numElements * scale, oldElementType );
946+ auto loadType = VectorType::get (numElements, containerElemTy );
947+ auto newBitcastType = VectorType::get (numElements * scale, emulatedElemTy );
938948
939949 auto emptyVector = rewriter.create <arith::ConstantOp>(
940950 loc, newBitcastType, rewriter.getZeroAttr (newBitcastType));
@@ -1009,23 +1019,25 @@ struct ConvertVectorTransferRead final
10091019 " only 1-D vectors are supported ATM" );
10101020
10111021 auto loc = op.getLoc ();
1012- auto convertedType = cast<MemRefType>(adaptor.getSource ().getType ());
1013- Type oldElementType = op.getType ().getElementType ();
1014- Type newElementType = convertedType.getElementType ();
1015- int srcBits = oldElementType.getIntOrFloatBitWidth ();
1016- int dstBits = newElementType.getIntOrFloatBitWidth ();
1017-
1018- if (dstBits % srcBits != 0 ) {
1022+ auto containerElemTy =
1023+ cast<MemRefType>(adaptor.getSource ().getType ()).getElementType ();
1024+ Type emulatedElemTy = op.getType ().getElementType ();
1025+ int emulatedBits = emulatedElemTy.getIntOrFloatBitWidth ();
1026+ int containerBits = containerElemTy.getIntOrFloatBitWidth ();
1027+
1028+ // Check per-element alignment.
1029+ if (containerBits % emulatedBits != 0 ) {
10191030 return rewriter.notifyMatchFailure (
1020- op, " only dstBits % srcBits == 0 supported" );
1031+ op, " impossible to pack emulated elements into container elements "
1032+ " (bit-wise misalignment)" );
10211033 }
1022- int scale = dstBits / srcBits ;
1034+ int scale = containerBits / emulatedBits ;
10231035
10241036 auto origElements = op.getVectorType ().getNumElements ();
10251037
10261038 bool isAlignedEmulation = origElements % scale == 0 ;
10271039
1028- auto newPadding = rewriter.create <arith::ExtUIOp>(loc, newElementType ,
1040+ auto newPadding = rewriter.create <arith::ExtUIOp>(loc, containerElemTy ,
10291041 adaptor.getPadding ());
10301042
10311043 auto stridedMetadata =
@@ -1035,7 +1047,7 @@ struct ConvertVectorTransferRead final
10351047 memref::LinearizedMemRefInfo linearizedInfo;
10361048 std::tie (linearizedInfo, linearizedIndices) =
10371049 memref::getLinearizedMemRefOffsetAndSize (
1038- rewriter, loc, srcBits, dstBits ,
1050+ rewriter, loc, emulatedBits, containerBits ,
10391051 stridedMetadata.getConstifiedMixedOffset (),
10401052 stridedMetadata.getConstifiedMixedSizes (),
10411053 stridedMetadata.getConstifiedMixedStrides (),
@@ -1051,12 +1063,12 @@ struct ConvertVectorTransferRead final
10511063 llvm::divideCeil (maxIntraDataOffset + origElements, scale);
10521064
10531065 auto newRead = rewriter.create <vector::TransferReadOp>(
1054- loc, VectorType::get (numElements, newElementType ), adaptor.getSource (),
1066+ loc, VectorType::get (numElements, containerElemTy ), adaptor.getSource (),
10551067 getValueOrCreateConstantIndexOp (rewriter, loc, linearizedIndices),
10561068 newPadding);
10571069
10581070 auto bitCast = rewriter.create <vector::BitCastOp>(
1059- loc, VectorType::get (numElements * scale, oldElementType ), newRead);
1071+ loc, VectorType::get (numElements * scale, emulatedElemTy ), newRead);
10601072
10611073 Value result = bitCast->getResult (0 );
10621074 if (!foldedIntraVectorOffset) {
0 commit comments