@@ -519,7 +519,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
519519
520520 auto origElements = valueToStore.getType ().getNumElements ();
521521 // Note, per-element-alignment was already verified above.
522- bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
522+ bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0 ;
523523
524524 auto stridedMetadata =
525525 rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
@@ -535,8 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
535535 getAsOpFoldResult (adaptor.getIndices ()));
536536
537537 std::optional<int64_t > foldedNumFrontPadElems =
538- isFullyAligned ? 0
539- : getConstantIntValue (linearizedInfo.intraDataOffset );
538+ isDivisibleInSize ? 0
539+ : getConstantIntValue (linearizedInfo.intraDataOffset );
540540
541541 if (!foldedNumFrontPadElems) {
542542 return rewriter.notifyMatchFailure (
@@ -554,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
554554 // need unaligned emulation because the store address is aligned and the
555555 // source is a whole byte.
556556 bool emulationRequiresPartialStores =
557- !isFullyAligned || *foldedNumFrontPadElems != 0 ;
557+ !isDivisibleInSize || *foldedNumFrontPadElems != 0 ;
558558 if (!emulationRequiresPartialStores) {
559559 // Basic case: storing full bytes.
560560 auto numElements = origElements / emulatedPerContainerElem;
@@ -881,7 +881,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
881881
882882 auto origElements = op.getVectorType ().getNumElements ();
883883 // Note, per-element-alignment was already verified above.
884- bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
884+ bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0 ;
885885
886886 auto stridedMetadata =
887887 rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
@@ -897,8 +897,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
897897 getAsOpFoldResult (adaptor.getIndices ()));
898898
899899 std::optional<int64_t > foldedIntraVectorOffset =
900- isFullyAligned ? 0
901- : getConstantIntValue (linearizedInfo.intraDataOffset );
900+ isDivisibleInSize ? 0
901+ : getConstantIntValue (linearizedInfo.intraDataOffset );
902902
903903 // Always load enough elements which can cover the original elements.
904904 int64_t maxintraDataOffset =
@@ -915,7 +915,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
915915 result = dynamicallyExtractSubVector (
916916 rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
917917 linearizedInfo.intraDataOffset , origElements);
918- } else if (!isFullyAligned ) {
918+ } else if (!isDivisibleInSize ) {
919919 result = staticallyExtractSubvector (
920920 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
921921 }
@@ -1002,7 +1002,7 @@ struct ConvertVectorMaskedLoad final
10021002 auto origType = op.getVectorType ();
10031003 auto origElements = origType.getNumElements ();
10041004 // Note, per-element-alignment was already verified above.
1005- bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
1005+ bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0 ;
10061006
10071007 auto stridedMetadata =
10081008 rewriter.create <memref::ExtractStridedMetadataOp>(loc, op.getBase ());
@@ -1017,8 +1017,8 @@ struct ConvertVectorMaskedLoad final
10171017 getAsOpFoldResult (adaptor.getIndices ()));
10181018
10191019 std::optional<int64_t > foldedIntraVectorOffset =
1020- isFullyAligned ? 0
1021- : getConstantIntValue (linearizedInfo.intraDataOffset );
1020+ isDivisibleInSize ? 0
1021+ : getConstantIntValue (linearizedInfo.intraDataOffset );
10221022
10231023 int64_t maxIntraDataOffset =
10241024 foldedIntraVectorOffset.value_or (emulatedPerContainerElem - 1 );
@@ -1042,7 +1042,7 @@ struct ConvertVectorMaskedLoad final
10421042 passthru = dynamicallyInsertSubVector (
10431043 rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset ,
10441044 origElements);
1045- } else if (!isFullyAligned ) {
1045+ } else if (!isDivisibleInSize ) {
10461046 passthru = staticallyInsertSubvector (rewriter, loc, passthru, emptyVector,
10471047 *foldedIntraVectorOffset);
10481048 }
@@ -1070,7 +1070,7 @@ struct ConvertVectorMaskedLoad final
10701070 mask = dynamicallyInsertSubVector (rewriter, loc, mask, emptyMask,
10711071 linearizedInfo.intraDataOffset ,
10721072 origElements);
1073- } else if (!isFullyAligned ) {
1073+ } else if (!isDivisibleInSize ) {
10741074 mask = staticallyInsertSubvector (rewriter, loc, op.getMask (), emptyMask,
10751075 *foldedIntraVectorOffset);
10761076 }
@@ -1081,7 +1081,7 @@ struct ConvertVectorMaskedLoad final
10811081 result = dynamicallyExtractSubVector (
10821082 rewriter, loc, result, op.getPassThru (),
10831083 linearizedInfo.intraDataOffset , origElements);
1084- } else if (!isFullyAligned ) {
1084+ } else if (!isDivisibleInSize ) {
10851085 result = staticallyExtractSubvector (
10861086 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
10871087 }
@@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final
10911091 }
10921092};
10931093
1094+ // / Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`
1095+ // /
1096+ // / "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1097+ // / vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1098+ // / (a multi-byte scalar, e.g. i16), where N is some integer.
1099+ // /
1100+ // / Put differently, this method checks whether this would be valid:
1101+ // /
1102+ // / vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1103+ // /
1104+ // / EXAMPLES:
1105+ // / * vector<4xi4> -> i16 - yes (N = 1)
1106+ // / * vector<4xi4> -> i8 - yes (N = 2)
1107+ // / * vector<3xi4> -> i8 - no (N would have to be 1.5)
1108+ // / * vector<3xi2> -> i16 - no (N would have to be 0.5)
1109+ static bool fitsInMultiByteContainerTy (VectorType subByteVecTy,
1110+ Type multiByteScalarTy) {
1111+ assert ((isa<IntegerType, FloatType>(multiByteScalarTy)) && " Not scalar!" );
1112+
1113+ int subByteBits = subByteVecTy.getElementType ().getIntOrFloatBitWidth ();
1114+ int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth ();
1115+
1116+ assert (subByteBits < 8 && " Not a sub-byte scalar type!" );
1117+ assert (multiByteBits % 8 == 0 && " Not a multi-byte scalar type!" );
1118+ assert (multiByteBits % subByteBits == 0 && " Unalagined element types!" );
1119+
1120+ int elemsPerMultiByte = multiByteBits / subByteBits;
1121+
1122+ // TODO: This is a bit too restrictive for vectors of rank > 1.
1123+ return subByteVecTy.getShape ().back () % elemsPerMultiByte == 0 ;
1124+ }
1125+
10941126// ===----------------------------------------------------------------------===//
10951127// ConvertVectorTransferRead
10961128// ===----------------------------------------------------------------------===//
@@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final
11271159 auto origElements = op.getVectorType ().getNumElements ();
11281160
11291161 // Note, per-element-alignment was already verified above.
1130- bool isFullyAligned = origElements % emulatedPerContainerElem == 0 ;
1162+ bool isDivisibleInSize =
1163+ fitsInMultiByteContainerTy (op.getVectorType (), containerElemTy);
11311164
11321165 auto newPadding = rewriter.create <arith::ExtUIOp>(loc, containerElemTy,
11331166 adaptor.getPadding ());
@@ -1146,8 +1179,8 @@ struct ConvertVectorTransferRead final
11461179 getAsOpFoldResult (adaptor.getIndices ()));
11471180
11481181 std::optional<int64_t > foldedIntraVectorOffset =
1149- isFullyAligned ? 0
1150- : getConstantIntValue (linearizedInfo.intraDataOffset );
1182+ isDivisibleInSize ? 0
1183+ : getConstantIntValue (linearizedInfo.intraDataOffset );
11511184
11521185 int64_t maxIntraDataOffset =
11531186 foldedIntraVectorOffset.value_or (emulatedPerContainerElem - 1 );
@@ -1171,7 +1204,7 @@ struct ConvertVectorTransferRead final
11711204 result = dynamicallyExtractSubVector (rewriter, loc, bitCast, zeros,
11721205 linearizedInfo.intraDataOffset ,
11731206 origElements);
1174- } else if (!isFullyAligned ) {
1207+ } else if (!isDivisibleInSize ) {
11751208 result = staticallyExtractSubvector (
11761209 rewriter, loc, result, *foldedIntraVectorOffset, origElements);
11771210 }
@@ -1428,41 +1461,69 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
14281461 return commonConversionPrecondition (rewriter, preconditionType, op);
14291462}
14301463
1431- // / Verify that `subByteVecType` and `dstType` are aligned. Alignment
1432- // / means that:
1433- // / 1. The `dstType` element type is a multiple of the
1434- // / `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1435- // / is not supported). Let this multiple be `N`.
1436- // / 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
1437- // / multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
1438- // / not supported).
1464+ // / Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1465+ // /
1466+ // / Alignment means that `subByteVecTy` can be packed into a vector of
1467+ // / `containerTy` elements. More specifically:
1468+ // / 1. The bit-width of `containerTy` is a multiple of the
1469+ // / bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1470+ // / this multiple is 4.
1471+ // / 2. The multiple from 1. above divides evenly the number of the (trailing)
1472+ // / elements in `subByteVecTy`.
1473+ // /
1474+ // / EXAMPLE 1:
1475+ // / `subByteVecTy = vector<4xi4>`, and
1476+ // / `containerTy = i16`
1477+ // /
1478+ // / 4 (= 16 / 4) divides evenly 4, hence both conditions are _met_.
1479+ // /
1480+ // / EXAMPLE 2:
1481+ // / `subByteVecTy = vector<3xi4>`, and
1482+ // / `containerTy = i16`
1483+ // /
1484+ // / 4 (= 16/4) _does not_ divide evenly 3, hence the conditions are _not met_.
1485+ // /
1486+ // / EXAMPLE 3:
1487+ // / `subByteVecTy = vector<3xi3>`, and
1488+ // / `containerTy = i16`
1489+ // /
1490+ // / 16 _is not_ a multiple of 3, hence the conditions are _not met_.
14391491// /
14401492// / NOTE: This method assumes that common conversion preconditions are met. In
1441- // / particular, the element type of `dstType ` is assumed to be a multi-byte
1442- // / type (e.g. i8, i16, i32).
1493+ // / particular, `containerTy ` is assumed to be a
1494+ // / multi-byte scalar type (e.g., i8, i16, i32).
14431495static LogicalResult alignedConversionPrecondition (PatternRewriter &rewriter,
1444- VectorType subByteVecType ,
1445- VectorType dstType ,
1496+ VectorType subByteVecTy ,
1497+ Type containerTy ,
14461498 Operation *op) {
1447- if (!subByteVecType || !dstType)
1448- return rewriter.notifyMatchFailure (op, " Not a supported aligned case" );
1449- unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth ();
1450- unsigned dstElemBitwidth = dstType.getElementTypeBitWidth ();
1499+ assert (containerTy.isIntOrFloat () &&
1500+ " container element type is not a scalar" );
14511501
1452- if (dstElemBitwidth < 8 )
1453- return rewriter.notifyMatchFailure (
1454- op, " the bitwidth of dstType must be greater than or equal to 8" );
1455- if (dstElemBitwidth % srcElemBitwidth != 0 )
1456- return rewriter.notifyMatchFailure (op, " unaligned cases are not supported" );
1457- if (srcElemBitwidth != 2 && srcElemBitwidth != 4 )
1502+ // TODO: This is validating the inputs rather than checking the conditions
1503+ // documented above. Replace with an assert.
1504+ if (!subByteVecTy)
1505+ return rewriter.notifyMatchFailure (op, " not a vector!" );
1506+
1507+ unsigned subByteBits = subByteVecTy.getElementTypeBitWidth ();
1508+ unsigned containerBits = containerTy.getIntOrFloatBitWidth ();
1509+
1510+ // Enforced by the common pre-conditions.
1511+ assert (containerBits % 8 == 0 && " Not a multi-byte scalar type!" );
1512+
1513+ // TODO: Add support for other widths (when/if needed)
1514+ if (subByteBits != 2 && subByteBits != 4 )
14581515 return rewriter.notifyMatchFailure (
1459- op, " only src bitwidth of 2 or 4 is supported at this moment" );
1516+ op, " only 2-bit and 4-bit sub-byte type is supported at this moment" );
1517+
1518+ // Condition 1 ("per-element" alignment)
1519+ if (containerBits % subByteBits != 0 )
1520+ return rewriter.notifyMatchFailure (op, " unalagined element types" );
14601521
1461- const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1462- if ((subByteVecType. getShape (). back () % numSrcElemsPerByte) != 0 )
1522+ // Condition 2 ("full" alignment)
1523+ if (! fitsInMultiByteContainerTy (subByteVecTy, containerTy) )
14631524 return rewriter.notifyMatchFailure (
1464- op, " the trailing dimension of the input vector of sub-bytes must be a "
1465- " multiple of 8 / <sub -byte-width> " );
1525+ op, " not possible to fit this sub-byte vector type into a vector of "
1526+ " the given multi -byte type " );
14661527
14671528 return success ();
14681529}
@@ -1899,8 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
18991960 return failure ();
19001961
19011962 // Check general alignment preconditions.
1902- if (failed (alignedConversionPrecondition (rewriter, srcVecType, dstVecType,
1903- conversionOp)))
1963+ if (failed (alignedConversionPrecondition (
1964+ rewriter, srcVecType,
1965+ /* containerTy=*/ rewriter.getI8Type (), conversionOp)))
19041966 return failure ();
19051967
19061968 // Perform the rewrite.
@@ -1964,8 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
19642026
19652027 // Check general alignment preconditions. We invert the src/dst type order
19662028 // to reuse the existing precondition logic.
1967- if (failed (alignedConversionPrecondition (rewriter, dstVecType, srcVecType,
1968- truncOp)))
2029+ if (failed (alignedConversionPrecondition (
2030+ rewriter, dstVecType,
2031+ /* containerTy=*/ rewriter.getI8Type (), truncOp)))
19692032 return failure ();
19702033
19712034 // Create a new iX -> i8 truncation op.
0 commit comments