@@ -1090,15 +1090,20 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
10901090 unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth ();
10911091 unsigned dstElemBitwidth = dstType.getElementTypeBitWidth ();
10921092
1093- // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
1094- if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
1095- (dstElemBitwidth % srcElemBitwidth) != 0 )
1096- return rewriter.notifyMatchFailure (op, " Not a supported aligned case" );
1093+ if (dstElemBitwidth < 8 )
1094+ return rewriter.notifyMatchFailure (
1095+ op, " the bitwidth of dstType must be greater than or equal to 8" );
1096+ if (dstElemBitwidth % srcElemBitwidth != 0 )
1097+ return rewriter.notifyMatchFailure (op, " unaligned cases are not supported" );
1098+ if (srcElemBitwidth != 2 && srcElemBitwidth != 4 )
1099+ return rewriter.notifyMatchFailure (
1100+ op, " only src bitwidth of 2 or 4 is supported at this moment" );
10971101
1098- const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
1099- if ((subByteVecType.getShape ().back () % numSrcElemsPerDestElem ) != 0 )
1102+ const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1103+ if ((subByteVecType.getShape ().back () % numSrcElemsPerByte ) != 0 )
11001104 return rewriter.notifyMatchFailure (
1101- op, " Not an even number of i4 elements in trailing dim" );
1105+ op, " the trailing dimension of the input vector of sub-bytes must be a "
1106+ " multiple of 8 / <sub-byte-width>" );
11021107
11031108 return success ();
11041109}
@@ -1179,70 +1184,166 @@ Value BitCastRewriter::genericRewriteStep(
11791184 return runningResult;
11801185}
11811186
1182- // / Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1183- // / bitwise ops that take advantage of high-level information to avoid leaving
1184- // / LLVM to scramble with peephole optimizations.
1185- static Value rewriteI4ToI8SignedExt (PatternRewriter &rewriter, Location loc,
1186- Value srcValue) {
1187- VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1188- assert (srcVecType.getElementType ().isSignlessInteger (4 ) &&
1189- " Expected i4 type" );
1187+ // / Bitcasts the aligned `subByteVec` vector to a vector of i8.
1188+ // / Where aligned means it satisfies `alignedConversionPrecondition`.
1189+ // /
1190+ // / Example:
1191+ // / vector<16x16xi2> -> vector<16x4xi8>
1192+ // / vector<16x16xi4> -> vector<16x8xi8>
1193+ static Value bitcastSubByteVectorToI8 (PatternRewriter &rewriter, Location loc,
1194+ Value subByteVec) {
1195+ auto srcVecType = cast<VectorType>(subByteVec.getType ());
1196+ int64_t srcBitwidth = srcVecType.getElementType ().getIntOrFloatBitWidth ();
1197+ assert (8 % srcBitwidth == 0 &&
1198+ " Unsupported sub-byte type (not a divisor of i8)" );
1199+ int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1200+ SmallVector<int64_t > vecShape (srcVecType.getShape ());
1201+ // Adjust last dimension of the vector, so the total size remains the same.
1202+ vecShape.back () = vecShape.back () / numSrcElemsPerByte;
1203+ auto i8VecType = VectorType::get (vecShape, rewriter.getI8Type ());
1204+ return rewriter.create <vector::BitCastOp>(loc, i8VecType, subByteVec);
1205+ }
11901206
1191- // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1192- SmallVector<int64_t > i8VecShape = llvm::to_vector (srcVecType.getShape ());
1193- constexpr int64_t i4Toi8BitwidthFactor = 2 ;
1194- i8VecShape.back () = i8VecShape.back () / i4Toi8BitwidthFactor;
1195- auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
1196- Value i8Vector = rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
1207+ // / Extracts a signed N-bit sequence from each element of a vector of bytes,
1208+ // / starting at the specified bit index.
1209+ // / The `bitIdx` starts at 0 from the LSB and moves to the left.
1210+ // /
1211+ // / Example for a single element:
1212+ // / Extract numBits=2 starting at bitIdx=2
1213+ // / src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1214+ // / indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1215+ // / target = [. . . . ^ ^ . .]
1216+ // /
1217+ // / The target sequence is [11](decimal=-1) as signed 2-bit integer.
1218+ // / So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1219+ // /
1220+ // / src = [01 01 11 10]
1221+ // / shl = arith.shl(src, 4) -> [11 10 00 00]
1222+ // / result = arith.shrsi(shl, 6) -> [11 11 11 11]
1223+ static Value extractNBitsPerByteAndSignExtendToI8 (PatternRewriter &rewriter,
1224+ Location loc, Value src,
1225+ int bitIdx, int numBits) {
1226+ auto srcType = cast<VectorType>(src.getType ());
1227+ Value shl = src;
1228+ int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1229+ assert (bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1230+ " Invalid bitIdx range" );
1231+ if (bitsToShiftLeft != 0 ) {
1232+ Value shiftLeftValues = rewriter.create <arith::ConstantOp>(
1233+ loc, DenseElementsAttr::get (srcType, bitsToShiftLeft));
1234+ shl = rewriter.create <arith::ShLIOp>(loc, src, shiftLeftValues);
1235+ }
11971236
1198- // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
1199- // byte are place in one vector and the high i4 elements in another vector.
1200- constexpr int8_t bitsToShift = 4 ;
1201- auto shiftValues = rewriter.create <arith::ConstantOp>(
1202- loc, DenseElementsAttr::get (i8VecType, bitsToShift));
1203- Value shl = rewriter.create <arith::ShLIOp>(loc, i8Vector, shiftValues);
1204- Value low = rewriter.create <arith::ShRSIOp>(loc, shl, shiftValues);
1205- Value high = rewriter.create <arith::ShRSIOp>(loc, i8Vector, shiftValues);
1237+ int8_t bitsToShiftRight = 8 - numBits;
1238+ Value shiftRightValues = rewriter.create <arith::ConstantOp>(
1239+ loc, DenseElementsAttr::get (srcType, bitsToShiftRight));
1240+ Value shr = rewriter.create <arith::ShRSIOp>(loc, shl, shiftRightValues);
1241+ return shr;
1242+ }
12061243
1207- // 3. Interleave low and high i8 elements.
1208- return rewriter.create <vector::InterleaveOp>(loc, low, high);
1244+ // / Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1245+ // / starting at the specified bit index.
1246+ // / The `bitIdx` starts at 0 from the LSB and moves to the left.
1247+ // /
1248+ // / Example for a single element:
1249+ // / Extract numBits=2 starting at bitIdx=2
1250+ // / src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1251+ // / indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1252+ // / target = [. . . . ^ ^ . .]
1253+ // /
1254+ // / The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1255+ // / So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1256+ // /
1257+ // / src = [01 01 10 10]
1258+ // / mask = [00 00 00 11]
1259+ // / shr = arith.shrui(src, 2) = [00 01 01 10]
1260+ // / result = arith.andi(shr, mask) = [00 00 00 10]
1261+ // / NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1262+ // / achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1263+ // / However, by using arith::ShRUIOp + arith::AndIOp, no shift has to be
1264+ // / emitted at all when the index is 0.
1265+ static Value extractNBitsPerByteAndExtendToI8 (PatternRewriter &rewriter,
1266+ Location loc, Value src,
1267+ int bitIdx, int numBits) {
1268+ assert (bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1269+ " Invalid bitIdx range" );
1270+ auto srcType = cast<VectorType>(src.getType ());
1271+ int8_t bitsToShiftRight = bitIdx;
1272+ Value shr = src;
1273+ if (bitsToShiftRight != 0 ) {
1274+ Value shiftRightValues = rewriter.create <arith::ConstantOp>(
1275+ loc, DenseElementsAttr::get (srcType, bitsToShiftRight));
1276+ shr = rewriter.create <arith::ShRUIOp>(loc, src, shiftRightValues);
1277+ }
1278+ if (bitIdx + numBits == 8 ) {
1279+ return shr;
1280+ }
1281+ uint8_t lowBitsMask = (1 << numBits) - 1 ;
1282+ Value lowBitsMaskValues = rewriter.create <arith::ConstantOp>(
1283+ loc, DenseElementsAttr::get (srcType, lowBitsMask));
1284+ return rewriter.create <arith::AndIOp>(loc, shr, lowBitsMaskValues);
12091285}
12101286
1211- // / Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1212- // / bitwise ops that take advantage of high-level information to avoid leaving
1213- // / LLVM to scramble with peephole optimizations.
1214- static Value rewriteI4ToI8UnsignedExt (PatternRewriter &rewriter, Location loc,
1215- Value srcValue) {
1216- VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1287+ using ExtractNBitsFn =
1288+ std::function<Value(PatternRewriter &, Location, Value, int , int )>;
1289+
1290+ // / Rewrite the i4 -> i8 extension into a sequence of shuffles and
1291+ // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1292+ static Value rewriteI4ToI8Ext (PatternRewriter &rewriter, Location loc,
1293+ Value srcValue, const ExtractNBitsFn &extFn) {
1294+ auto srcVecType = cast<VectorType>(srcValue.getType ());
12171295 assert (srcVecType.getElementType ().isSignlessInteger (4 ) &&
12181296 " Expected i4 type" );
12191297
12201298 // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1221- SmallVector<int64_t > i8VecShape = llvm::to_vector (srcVecType.getShape ());
1222- constexpr int64_t i4Toi8BitwidthFactor = 2 ;
1223- i8VecShape.back () = i8VecShape.back () / i4Toi8BitwidthFactor;
1224- auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
1225- Value i8Vector = rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
1226-
1227- // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
1228- // byte are placed in one vector and the high i4 elements in another vector.
1229- constexpr uint8_t lowBitsMask = 15 ; // Equivalent to [00001111] bit mask
1230- auto lowBitsMaskValues = rewriter.create <arith::ConstantOp>(
1231- loc, DenseElementsAttr::get (i8VecType, lowBitsMask));
1232- Value low = rewriter.create <arith::AndIOp>(loc, i8VecType, i8Vector,
1233- lowBitsMaskValues);
1234- constexpr int8_t highBitsToShift = 4 ;
1235- auto highShiftValues = rewriter.create <arith::ConstantOp>(
1236- loc, DenseElementsAttr::get (i8VecType, highBitsToShift));
1237- Value high = rewriter.create <arith::ShRUIOp>(loc, i8Vector, highShiftValues);
1299+ Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
1300+
1301+ // 2. Extend i4 elements to i8 elements. Low i4 elements of each
1302+ // byte are placed in one vector and the high i4 elements in another vector.
1303+ Value low = extFn (rewriter, loc, i8Vector, 0 , 4 );
1304+ Value high = extFn (rewriter, loc, i8Vector, 4 , 4 );
12381305
12391306 // 3. Interleave low and high i8 elements.
12401307 return rewriter.create <vector::InterleaveOp>(loc, low, high);
12411308}
12421309
1310+ // / Rewrite the i2 -> i8 extension into a sequence of shuffles and
1311+ // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1312+ static Value rewriteI2ToI8Ext (PatternRewriter &rewriter, Location loc,
1313+ Value srcValue, const ExtractNBitsFn &extFn) {
1314+ VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1315+ assert (srcVecType.getElementType ().isSignlessInteger (2 ) &&
1316+ " Expected i2 type" );
1317+
1318+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
1319+ Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
1320+
1321+ // 2. Extract each i2 element
1322+ // Position 0 (bits 0-1)
1323+ Value vec0 = extFn (rewriter, loc, i8Vector, 0 , 2 );
1324+ // Position 1 (bits 2-3)
1325+ Value vec1 = extFn (rewriter, loc, i8Vector, 2 , 2 );
1326+ // Position 2 (bits 4-5)
1327+ Value vec2 = extFn (rewriter, loc, i8Vector, 4 , 2 );
1328+ // Position 3 (bits 6-7)
1329+ Value vec3 = extFn (rewriter, loc, i8Vector, 6 , 2 );
1330+
1331+ // 3. Interleave all 4 vectors: first interleave the even positions (vec0, vec2)
1332+ // and the odd positions (vec1, vec3), then interleave the two results.
1333+ // vec0 = [0,0,0,0],...
1334+ // vec1 = [1,1,1,1],...
1335+ // vec2 = [2,2,2,2],...
1336+ // vec3 = [3,3,3,3],...
1337+ // 02 = [0,2,0,2,0,2,0,2],...
1338+ // 13 = [1,3,1,3,1,3,1,3],...
1339+ // 0213 = [0,1,2,3,...],...
1340+ Value interleave02 = rewriter.create <vector::InterleaveOp>(loc, vec0, vec2);
1341+ Value interleave13 = rewriter.create <vector::InterleaveOp>(loc, vec1, vec3);
1342+ return rewriter.create <vector::InterleaveOp>(loc, interleave02, interleave13);
1343+ }
1344+
12431345// / Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1244- // / ops that take advantage of high-level information to avoid leaving LLVM to
1245- // / scramble with peephole optimizations.
1346+ // / ops to avoid leaving LLVM to scramble with peephole optimizations.
12461347static Value rewriteI8ToI4Trunc (PatternRewriter &rewriter, Location loc,
12471348 Value srcValue) {
12481349 VectorType srcVecType = cast<VectorType>(srcValue.getType ());
@@ -1443,13 +1544,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
14431544 return failure ();
14441545
14451546 // Perform the rewrite.
1547+ Location loc = conversionOp.getLoc ();
1548+ const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
1549+ : extractNBitsPerByteAndExtendToI8;
14461550 Value subByteExt;
1447- if (isSigned) {
1448- subByteExt =
1449- rewriteI4ToI8SignedExt (rewriter, conversionOp.getLoc (), srcValue);
1450- } else {
1451- subByteExt =
1452- rewriteI4ToI8UnsignedExt (rewriter, conversionOp.getLoc (), srcValue);
1551+ switch (srcVecType.getElementType ().getIntOrFloatBitWidth ()) {
1552+ case 2 :
1553+ subByteExt = rewriteI2ToI8Ext (rewriter, loc, srcValue, extFn);
1554+ break ;
1555+ case 4 :
1556+ subByteExt = rewriteI4ToI8Ext (rewriter, loc, srcValue, extFn);
1557+ break ;
1558+ default :
1559+ return failure ();
14531560 }
14541561
14551562 // Finalize the rewrite.
@@ -1490,6 +1597,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
14901597 if (failed (commonConversionPrecondition (rewriter, srcVecType, truncOp)))
14911598 return failure ();
14921599
1600+ // TODO: Add support for truncating to i2.
1601+ if (dstVecType.getElementType ().getIntOrFloatBitWidth () == 2 )
1602+ return failure ();
1603+
14931604 // Check general alignment preconditions. We invert the src/dst type order
14941605 // to reuse the existing precondition logic.
14951606 if (failed (alignedConversionPrecondition (rewriter, dstVecType, srcVecType,
0 commit comments