@@ -1179,112 +1179,83 @@ Value BitCastRewriter::genericRewriteStep(
11791179 return runningResult;
11801180}
11811181
1182- // / Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1183- // / bitwise ops that take advantage of high-level information to avoid leaving
1184- // / LLVM to scramble with peephole optimizations.
1185- static Value rewriteI4ToI8SignedExt (PatternRewriter &rewriter, Location loc,
1186- Value srcValue) {
1182+ Value bitcastSubByteVectorToI8 (PatternRewriter &rewriter, Location loc,
1183+ Value srcValue) {
11871184 VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1188- assert (srcVecType.getElementType ().isSignlessInteger (4 ) &&
1189- " Expected i4 type" );
1190-
1191- // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1185+ int64_t srcBitwidth = srcVecType.getElementType ().getIntOrFloatBitWidth ();
1186+ assert (srcBitwidth % 8 != 0 && " Invalid source bitwidth" );
1187+ int64_t bitwidthFactor = 8 / srcBitwidth;
11921188 SmallVector<int64_t > i8VecShape = llvm::to_vector (srcVecType.getShape ());
1193- constexpr int64_t i4Toi8BitwidthFactor = 2 ;
1194- i8VecShape.back () = i8VecShape.back () / i4Toi8BitwidthFactor;
1189+ i8VecShape.back () = i8VecShape.back () / bitwidthFactor;
11951190 auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
1196- Value i8Vector = rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
1191+ return rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
1192+ }
11971193
1198- // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
1199- // byte are place in one vector and the high i4 elements in another vector.
1200- constexpr int8_t bitsToShift = 4 ;
1201- auto shiftValues = rewriter.create <arith::ConstantOp>(
1202- loc, DenseElementsAttr::get (i8VecType, bitsToShift));
1203- Value shl = rewriter.create <arith::ShLIOp>(loc, i8Vector, shiftValues);
1204- Value low = rewriter.create <arith::ShRSIOp>(loc, shl, shiftValues);
1205- Value high = rewriter.create <arith::ShRSIOp>(loc, i8Vector, shiftValues);
1194+ // / Extracts a signed N-bit sequence from each element of an 8-bit vector,
1195+ // / starting at the specified bit index.
1196+ Value extractNBitsFromVectorSigned (PatternRewriter &rewriter, Location loc,
1197+ Value src, int bitIdx, int numBits) {
1198+ assert (bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1199+ " Invalid bitIdx range" );
1200+ auto srcType = cast<VectorType>(src.getType ());
1201+ Value shl = src;
1202+ int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1203+ if (bitsToShiftLeft != 0 ) {
1204+ Value shiftLeftValues = rewriter.create <arith::ConstantOp>(
1205+ loc, DenseElementsAttr::get (srcType, bitsToShiftLeft));
1206+ shl = rewriter.create <arith::ShLIOp>(loc, src, shiftLeftValues);
1207+ }
12061208
1207- // 3. Interleave low and high i8 elements.
1208- return rewriter.create <vector::InterleaveOp>(loc, low, high);
1209+ int8_t bitsToShiftRight = 8 - numBits;
1210+ Value shiftRightValues = rewriter.create <arith::ConstantOp>(
1211+ loc, DenseElementsAttr::get (srcType, bitsToShiftRight));
1212+ Value shr = rewriter.create <arith::ShRSIOp>(loc, shl, shiftRightValues);
1213+ return shr;
12091214}
12101215
1211- // / Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1212- // / bitwise ops that take advantage of high-level information to avoid leaving
1213- // / LLVM to scramble with peephole optimizations.
1214- static Value rewriteI4ToI8UnsignedExt (PatternRewriter &rewriter, Location loc,
1215- Value srcValue) {
1216+ // / Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1217+ // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1218+ static Value rewriteI4ToI8SignedExt (PatternRewriter &rewriter, Location loc,
1219+ Value srcValue) {
12161220 VectorType srcVecType = cast<VectorType>(srcValue.getType ());
12171221 assert (srcVecType.getElementType ().isSignlessInteger (4 ) &&
12181222 " Expected i4 type" );
12191223
12201224 // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1221- SmallVector<int64_t > i8VecShape = llvm::to_vector (srcVecType.getShape ());
1222- constexpr int64_t i4Toi8BitwidthFactor = 2 ;
1223- i8VecShape.back () = i8VecShape.back () / i4Toi8BitwidthFactor;
1224- auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
1225- Value i8Vector = rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
1225+ Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
12261226
1227- // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
1228- // byte are placed in one vector and the high i4 elements in another vector.
1229- constexpr uint8_t lowBitsMask = 15 ; // Equivalent to [00001111] bit mask
1230- auto lowBitsMaskValues = rewriter.create <arith::ConstantOp>(
1231- loc, DenseElementsAttr::get (i8VecType, lowBitsMask));
1232- Value low = rewriter.create <arith::AndIOp>(loc, i8VecType, i8Vector,
1233- lowBitsMaskValues);
1234- constexpr int8_t highBitsToShift = 4 ;
1235- auto highShiftValues = rewriter.create <arith::ConstantOp>(
1236- loc, DenseElementsAttr::get (i8VecType, highBitsToShift));
1237- Value high = rewriter.create <arith::ShRUIOp>(loc, i8Vector, highShiftValues);
1227+ // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
1228+ // byte are place in one vector and the high i4 elements in another vector.
1229+ Value low = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 0 , 4 );
1230+ Value high = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 4 , 4 );
12381231
12391232 // 3. Interleave low and high i8 elements.
12401233 return rewriter.create <vector::InterleaveOp>(loc, low, high);
12411234}
12421235
12431236// / Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
1244- // / bitwise ops that take advantage of high-level information to avoid leaving
1245- // / LLVM to scramble with peephole optimizations.
1237+ // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
12461238static Value rewriteI2ToI8SignedExt (PatternRewriter &rewriter, Location loc,
12471239 Value srcValue) {
12481240 VectorType srcVecType = cast<VectorType>(srcValue.getType ());
12491241 assert (srcVecType.getElementType ().isSignlessInteger (2 ) &&
12501242 " Expected i2 type" );
12511243
12521244 // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1253- SmallVector<int64_t > i8VecShape = llvm::to_vector (srcVecType.getShape ());
1254- constexpr int64_t i2Toi8BitwidthFactor = 4 ;
1255- i8VecShape.back () = i8VecShape.back () / i2Toi8BitwidthFactor;
1256- auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
1257- Value i8Vector = rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
1245+ Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
12581246
1247+ // 2. Extract each i2 element using shifts
12591248 // Element 0 (bits 0-1)
1260- constexpr int8_t shiftConst6 = 6 ;
1261- auto shiftAttr6 = DenseElementsAttr::get (i8VecType, shiftConst6);
1262- auto shiftValues6 = rewriter.create <arith::ConstantOp>(loc, shiftAttr6);
1263- Value shl0 = rewriter.create <arith::ShLIOp>(loc, i8Vector, shiftValues6);
1264- Value elem0 = rewriter.create <arith::ShRSIOp>(loc, shl0, shiftValues6);
1265-
1249+ Value elem0 = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 0 , 2 );
12661250 // Element 1 (bits 2-3)
1267- constexpr int8_t shiftConst4 = 4 ;
1268- auto shiftAttr4 = DenseElementsAttr::get (i8VecType, shiftConst4);
1269- auto shiftValues4 = rewriter.create <arith::ConstantOp>(loc, shiftAttr4);
1270- Value shl1 = rewriter.create <arith::ShLIOp>(loc, i8Vector, shiftValues4);
1271- Value elem1 = rewriter.create <arith::ShRSIOp>(loc, shl1, shiftValues6);
1272-
1251+ Value elem1 = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 2 , 2 );
12731252 // Element 2 (bits 4-5)
1274- constexpr int8_t shiftConst2 = 2 ;
1275- auto shiftAttr2 = DenseElementsAttr::get (i8VecType, shiftConst2);
1276- auto shiftValues2 = rewriter.create <arith::ConstantOp>(loc, shiftAttr2);
1277- Value shl2 = rewriter.create <arith::ShLIOp>(loc, i8Vector, shiftValues2);
1278- Value elem2 = rewriter.create <arith::ShRSIOp>(loc, shl2, shiftValues6);
1279-
1253+ Value elem2 = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 4 , 2 );
12801254 // Element 3 (bits 6-7)
1281- Value elem3 = rewriter. create <arith::ShRSIOp>( loc, i8Vector, shiftValues6 );
1255+ Value elem3 = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 6 , 2 );
12821256
1283- // interleave all 4 elements by first interleaving even elements and then odd
1284- // elem0 = [0,0,0,0]
1285- // elem1 = [1,1,1,1]
1286- // elem2 = [2,2,2,2]
1287- // elem3 = [3,3,3,3]
1257+ // 3. Interleave all 4 elements by first interleaving even elements and then
1258+ // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
12881259 // 02 = [0,2,0,2]
12891260 // 13 = [1,3,1,3]
12901261 // 0213 = [0,1,2,3]
@@ -1293,56 +1264,72 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
12931264 return rewriter.create <vector::InterleaveOp>(loc, interleave02, interleave13);
12941265}
12951266
1267+ // / Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
1268+ // / starting at the specified bit index.
1269+ Value extractNBitsFromVectorUnsinged (PatternRewriter &rewriter, Location loc,
1270+ Value src, int bitIdx, int numBits) {
1271+ assert (bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1272+ " Invalid bitIdx range" );
1273+ auto srcType = cast<VectorType>(src.getType ());
1274+ int8_t bitsToShiftRight = bitIdx;
1275+ Value shr = src;
1276+ if (bitsToShiftRight != 0 ) {
1277+ Value shiftRightValues = rewriter.create <arith::ConstantOp>(
1278+ loc, DenseElementsAttr::get (srcType, bitsToShiftRight));
1279+ shr = rewriter.create <arith::ShRUIOp>(loc, src, shiftRightValues);
1280+ }
1281+ if (bitIdx + numBits == 8 ) {
1282+ return shr;
1283+ }
1284+ uint8_t lowBitsMask = (1 << numBits) - 1 ;
1285+ Value lowBitsMaskValues = rewriter.create <arith::ConstantOp>(
1286+ loc, DenseElementsAttr::get (srcType, lowBitsMask));
1287+ return rewriter.create <arith::AndIOp>(loc, shr, lowBitsMaskValues);
1288+ }
1289+
1290+ // / Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1291+ // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1292+ static Value rewriteI4ToI8UnsignedExt (PatternRewriter &rewriter, Location loc,
1293+ Value srcValue) {
1294+ VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1295+ assert (srcVecType.getElementType ().isSignlessInteger (4 ) &&
1296+ " Expected i4 type" );
1297+
1298+ // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1299+ Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
1300+
1301+ // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
1302+ // byte are placed in one vector and the high i4 elements in another vector.
1303+ Value low = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 0 , 4 );
1304+ Value high = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 4 , 4 );
1305+
1306+ // 3. Interleave low and high i8 elements.
1307+ return rewriter.create <vector::InterleaveOp>(loc, low, high);
1308+ }
1309+
12961310// / Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
1297- // / bitwise ops that take advantage of high-level information to avoid leaving
1298- // / LLVM to scramble with peephole optimizations.
1311+ // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
12991312static Value rewriteI2ToI8UnsignedExt (PatternRewriter &rewriter, Location loc,
13001313 Value srcValue) {
13011314 VectorType srcVecType = cast<VectorType>(srcValue.getType ());
13021315 assert (srcVecType.getElementType ().isSignlessInteger (2 ) &&
13031316 " Expected i2 type" );
13041317
13051318 // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1306- SmallVector<int64_t > i8VecShape = llvm::to_vector (srcVecType.getShape ());
1307- constexpr int64_t i2Toi8BitwidthFactor = 4 ;
1308- i8VecShape.back () = i8VecShape.back () / i2Toi8BitwidthFactor;
1309- auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
1310- Value i8Vector = rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
1319+ Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
13111320
13121321 // 2. Extract each i2 element using shifts and masks
1313- constexpr uint8_t mask = 3 ; // Mask for 2 bits: [0000 0011]
1314- auto maskAttr = DenseElementsAttr::get (i8VecType, mask);
1315- auto maskValues = rewriter.create <arith::ConstantOp>(loc, maskAttr);
1316-
13171322 // Element 0 (bits 0-1)
1318- Value elem0 = rewriter.create <arith::AndIOp>(loc, i8Vector, maskValues);
1319-
1323+ Value elem0 = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 0 , 2 );
13201324 // Element 1 (bits 2-3)
1321- constexpr int8_t shift1 = 2 ;
1322- auto shiftAttr1 = DenseElementsAttr::get (i8VecType, shift1);
1323- auto shiftValues1 = rewriter.create <arith::ConstantOp>(loc, shiftAttr1);
1324- Value shifted1 = rewriter.create <arith::ShRUIOp>(loc, i8Vector, shiftValues1);
1325- Value elem1 = rewriter.create <arith::AndIOp>(loc, shifted1, maskValues);
1326-
1325+ Value elem1 = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 2 , 2 );
13271326 // Element 2 (bits 4-5)
1328- constexpr int8_t shift2 = 4 ;
1329- auto shiftAttr2 = DenseElementsAttr::get (i8VecType, shift2);
1330- auto shiftValues2 = rewriter.create <arith::ConstantOp>(loc, shiftAttr2);
1331- Value shifted2 = rewriter.create <arith::ShRUIOp>(loc, i8Vector, shiftValues2);
1332- Value elem2 = rewriter.create <arith::AndIOp>(loc, shifted2, maskValues);
1333-
1327+ Value elem2 = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 4 , 2 );
13341328 // Element 3 (bits 6-7)
1335- constexpr int8_t shift3 = 6 ;
1336- auto shiftAttr3 = DenseElementsAttr::get (i8VecType, shift3);
1337- auto shiftValues3 = rewriter.create <arith::ConstantOp>(loc, shiftAttr3);
1338- Value shifted3 = rewriter.create <arith::ShRUIOp>(loc, i8Vector, shiftValues3);
1339- Value elem3 = rewriter.create <arith::AndIOp>(loc, shifted3, maskValues);
1340-
1341- // interleave all 4 elements by first interleaving even elements and then odd
1342- // elem0 = [0,0,0,0]
1343- // elem1 = [1,1,1,1]
1344- // elem2 = [2,2,2,2]
1345- // elem3 = [3,3,3,3]
1329+ Value elem3 = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 6 , 2 );
1330+
1331+ // 3. Interleave all 4 elements by first interleaving even elements and then
1332+ // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
13461333 // 02 = [0,2,0,2]
13471334 // 13 = [1,3,1,3]
13481335 // 0213 = [0,1,2,3]
@@ -1352,8 +1339,7 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
13521339}
13531340
13541341// / Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1355- // / ops that take advantage of high-level information to avoid leaving LLVM to
1356- // / scramble with peephole optimizations.
1342+ // / ops to avoid leaving LLVM to scramble with peephole optimizations.
13571343static Value rewriteI8ToI4Trunc (PatternRewriter &rewriter, Location loc,
13581344 Value srcValue) {
13591345 VectorType srcVecType = cast<VectorType>(srcValue.getType ());
@@ -1556,20 +1542,30 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
15561542 // Perform the rewrite.
15571543 Value subByteExt;
15581544 if (isSigned) {
1559- if (srcVecType.getElementType ().getIntOrFloatBitWidth () == 2 )
1545+ switch (srcVecType.getElementType ().getIntOrFloatBitWidth ()) {
1546+ case 2 :
15601547 subByteExt =
15611548 rewriteI2ToI8SignedExt (rewriter, conversionOp.getLoc (), srcValue);
1562- else {
1549+ break ;
1550+ case 4 :
15631551 subByteExt =
15641552 rewriteI4ToI8SignedExt (rewriter, conversionOp.getLoc (), srcValue);
1553+ break ;
1554+ default :
1555+ return failure ();
15651556 }
15661557 } else {
1567- if (srcVecType.getElementType ().getIntOrFloatBitWidth () == 2 ) {
1558+ switch (srcVecType.getElementType ().getIntOrFloatBitWidth ()) {
1559+ case 2 :
15681560 subByteExt =
15691561 rewriteI2ToI8UnsignedExt (rewriter, conversionOp.getLoc (), srcValue);
1570- } else {
1562+ break ;
1563+ case 4 :
15711564 subByteExt =
15721565 rewriteI4ToI8UnsignedExt (rewriter, conversionOp.getLoc (), srcValue);
1566+ break ;
1567+ default :
1568+ return failure ();
15731569 }
15741570 }
15751571
@@ -1611,16 +1607,16 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
16111607 if (failed (commonConversionPrecondition (rewriter, srcVecType, truncOp)))
16121608 return failure ();
16131609
1610+ // TODO: Add support for truncating to i2.
1611+ if (dstVecType.getElementType ().getIntOrFloatBitWidth () == 2 )
1612+ return failure ();
1613+
16141614 // Check general alignment preconditions. We invert the src/dst type order
16151615 // to reuse the existing precondition logic.
16161616 if (failed (alignedConversionPrecondition (rewriter, dstVecType, srcVecType,
16171617 truncOp)))
16181618 return failure ();
16191619
1620- // not supported currently.
1621- if (dstVecType.getElementType ().getIntOrFloatBitWidth () == 2 )
1622- return failure ();
1623-
16241620 // Create a new iX -> i8 truncation op.
16251621 Location loc = truncOp.getLoc ();
16261622 auto i8VecType = srcVecType.cloneWith (std::nullopt , rewriter.getI8Type ());
0 commit comments