@@ -1090,10 +1090,14 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
10901090 unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth ();
10911091 unsigned dstElemBitwidth = dstType.getElementTypeBitWidth ();
10921092
1093- // Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
1094- if ((srcElemBitwidth != 4 && srcElemBitwidth != 2 ) || dstElemBitwidth < 8 ||
1095- (dstElemBitwidth % srcElemBitwidth) != 0 )
1096- return rewriter.notifyMatchFailure (op, " Not a supported aligned case" );
1093+ if (dstElemBitwidth < 8 )
1094+ return rewriter.notifyMatchFailure (
1095+ op, " the bitwidth of dstType must be greater than or equal to 8" );
1096+ if (dstElemBitwidth % srcElemBitwidth != 0 )
1097+ return rewriter.notifyMatchFailure (op, " unaligned cases are not supported" );
1098+ if (srcElemBitwidth != 2 && srcElemBitwidth != 4 )
1099+ return rewriter.notifyMatchFailure (
1100+ op, " only src bitwidth of 2 or 4 is supported at this moment" );
10971101
10981102 const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
10991103 if ((subByteVecType.getShape ().back () % numSrcElemsPerDestElem) != 0 )
@@ -1179,22 +1183,35 @@ Value BitCastRewriter::genericRewriteStep(
11791183 return runningResult;
11801184}
11811185
1182- Value bitcastSubByteVectorToI8 (PatternRewriter &rewriter, Location loc,
1183- Value srcValue) {
1184- VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1186+ // / Takes an aligned sub-byte vector as input and bitcasts it to a vector of i8.
1187+ // /
1188+ // / Example:
1189+ // / vector<16x16xi2> -> vector<16x4xi8>
1190+ // / vector<16x16xi4> -> vector<16x8xi8>
1191+ static Value bitcastSubByteVectorToI8 (PatternRewriter &rewriter, Location loc,
1192+ Value srcValue) {
1193+ auto srcVecType = cast<VectorType>(srcValue.getType ());
11851194 int64_t srcBitwidth = srcVecType.getElementType ().getIntOrFloatBitWidth ();
1186- assert (srcBitwidth % 8 ! = 0 && " Invalid source bitwidth" );
1195+ assert (8 % srcBitwidth = = 0 && " Invalid source bitwidth" );
11871196 int64_t bitwidthFactor = 8 / srcBitwidth;
1188- SmallVector<int64_t > i8VecShape = llvm::to_vector (srcVecType.getShape ());
1189- i8VecShape.back () = i8VecShape.back () / bitwidthFactor;
1190- auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
1197+ SmallVector<int64_t > vecShape (srcVecType.getShape ());
1198+ // Adjust the last dimension of the vector so the total size in bits remains the same.
1199+ vecShape.back () = vecShape.back () / bitwidthFactor;
1200+ auto i8VecType = VectorType::get (vecShape, rewriter.getI8Type ());
11911201 return rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
11921202}
11931203
11941204// / Extracts a signed N-bit sequence from each element of an 8-bit vector,
11951205// / starting at the specified bit index.
1196- Value extractNBitsFromVectorSigned (PatternRewriter &rewriter, Location loc,
1197- Value src, int bitIdx, int numBits) {
1206+ // /
1207+ // / Example:
1208+ // / extract numBits=2 starting at bitIdx=2
1209+ // / src = [0101|11|10]
1210+ // / shl = src << 4 -> [11100000]
1211+ // / result = shl >> 6 -> [11111111]
1212+ static Value extractNBitsFromVectorSigned (PatternRewriter &rewriter,
1213+ Location loc, Value src, int bitIdx,
1214+ int numBits) {
11981215 assert (bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
11991216 " Invalid bitIdx range" );
12001217 auto srcType = cast<VectorType>(src.getType ());
@@ -1213,61 +1230,18 @@ Value extractNBitsFromVectorSigned(PatternRewriter &rewriter, Location loc,
12131230 return shr;
12141231}
12151232
1216- // / Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1217- // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1218- static Value rewriteI4ToI8SignedExt (PatternRewriter &rewriter, Location loc,
1219- Value srcValue) {
1220- VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1221- assert (srcVecType.getElementType ().isSignlessInteger (4 ) &&
1222- " Expected i4 type" );
1223-
1224- // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1225- Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
1226-
1227- // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
1228- // byte are place in one vector and the high i4 elements in another vector.
1229- Value low = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 0 , 4 );
1230- Value high = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 4 , 4 );
1231-
1232- // 3. Interleave low and high i8 elements.
1233- return rewriter.create <vector::InterleaveOp>(loc, low, high);
1234- }
1235-
1236- // / Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
1237- // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1238- static Value rewriteI2ToI8SignedExt (PatternRewriter &rewriter, Location loc,
1239- Value srcValue) {
1240- VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1241- assert (srcVecType.getElementType ().isSignlessInteger (2 ) &&
1242- " Expected i2 type" );
1243-
1244- // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1245- Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
1246-
1247- // 2. Extract each i2 element using shifts
1248- // Element 0 (bits 0-1)
1249- Value elem0 = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 0 , 2 );
1250- // Element 1 (bits 2-3)
1251- Value elem1 = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 2 , 2 );
1252- // Element 2 (bits 4-5)
1253- Value elem2 = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 4 , 2 );
1254- // Element 3 (bits 6-7)
1255- Value elem3 = extractNBitsFromVectorSigned (rewriter, loc, i8Vector, 6 , 2 );
1256-
1257- // 3. Interleave all 4 elements by first interleaving even elements and then
1258- // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
1259- // 02 = [0,2,0,2]
1260- // 13 = [1,3,1,3]
1261- // 0213 = [0,1,2,3]
1262- Value interleave02 = rewriter.create <vector::InterleaveOp>(loc, elem0, elem2);
1263- Value interleave13 = rewriter.create <vector::InterleaveOp>(loc, elem1, elem3);
1264- return rewriter.create <vector::InterleaveOp>(loc, interleave02, interleave13);
1265- }
1266-
12671233// / Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
12681234// / starting at the specified bit index.
1269- Value extractNBitsFromVectorUnsinged (PatternRewriter &rewriter, Location loc,
1270- Value src, int bitIdx, int numBits) {
1235+ // /
1236+ // / Example:
1237+ // / extract numBits=2 starting at bitIdx=2
1238+ // / src = [0101|10|10]
1239+ // / mask = [00000011]
1240+ // / shr = src >> 2 = [00010110]
1241+ // / result = shr & mask = [00000010]
1242+ static Value extractNBitsFromVectorUnsinged (PatternRewriter &rewriter,
1243+ Location loc, Value src, int bitIdx,
1244+ int numBits) {
12711245 assert (bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
12721246 " Invalid bitIdx range" );
12731247 auto srcType = cast<VectorType>(src.getType ());
@@ -1287,49 +1261,56 @@ Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
12871261 return rewriter.create <arith::AndIOp>(loc, shr, lowBitsMaskValues);
12881262}
12891263
1290- // / Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1264+ using ExtractNBitsFn =
1265+ std::function<Value(PatternRewriter &, Location, Value, int , int )>;
1266+
1267+ // / Rewrite the i4 -> i8 extension into a sequence of shuffles and
12911268// / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1292- static Value rewriteI4ToI8UnsignedExt (PatternRewriter &rewriter, Location loc,
1293- Value srcValue ) {
1294- VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1269+ static Value rewriteI4ToI8Ext (PatternRewriter &rewriter, Location loc,
1270+ Value srcValue, const ExtractNBitsFn &extFn ) {
1271+ auto srcVecType = cast<VectorType>(srcValue.getType ());
12951272 assert (srcVecType.getElementType ().isSignlessInteger (4 ) &&
12961273 " Expected i4 type" );
12971274
12981275 // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
12991276 Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
13001277
1301- // 2 Extend the i4 elements using shifts & masking . Low i4 elements of each
1302- // byte are placed in one vector and the high i4 elements in another vector.
1303- Value low = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 0 , 4 );
1304- Value high = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 4 , 4 );
1278+ // 2. Extend i4 elements to i8 elements. Low i4 elements of each
1279+ // byte are placed in one vector and the high i4 elements in another vector.
1280+ Value low = extFn (rewriter, loc, i8Vector, 0 , 4 );
1281+ Value high = extFn (rewriter, loc, i8Vector, 4 , 4 );
13051282
13061283 // 3. Interleave low and high i8 elements.
13071284 return rewriter.create <vector::InterleaveOp>(loc, low, high);
13081285}
13091286
1310- // / Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
1287+ // / Rewrite the i2 -> i8 extension into a sequence of shuffles and
13111288// / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1312- static Value rewriteI2ToI8UnsignedExt (PatternRewriter &rewriter, Location loc,
1313- Value srcValue ) {
1289+ static Value rewriteI2ToI8Ext (PatternRewriter &rewriter, Location loc,
1290+ Value srcValue, const ExtractNBitsFn &extFn ) {
13141291 VectorType srcVecType = cast<VectorType>(srcValue.getType ());
13151292 assert (srcVecType.getElementType ().isSignlessInteger (2 ) &&
13161293 " Expected i2 type" );
13171294
13181295 // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
13191296 Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
13201297
1321- // 2. Extract each i2 element using shifts and masks
1298+ // 2. Extract each i2 element
13221299 // Element 0 (bits 0-1)
1323- Value elem0 = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 0 , 2 );
1300+ Value elem0 = extFn (rewriter, loc, i8Vector, 0 , 2 );
13241301 // Element 1 (bits 2-3)
1325- Value elem1 = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 2 , 2 );
1302+ Value elem1 = extFn (rewriter, loc, i8Vector, 2 , 2 );
13261303 // Element 2 (bits 4-5)
1327- Value elem2 = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 4 , 2 );
1304+ Value elem2 = extFn (rewriter, loc, i8Vector, 4 , 2 );
13281305 // Element 3 (bits 6-7)
1329- Value elem3 = extractNBitsFromVectorUnsinged (rewriter, loc, i8Vector, 6 , 2 );
1330-
1331- // 3. Interleave all 4 elements by first interleaving even elements and then
1332- // odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
1306+ Value elem3 = extFn (rewriter, loc, i8Vector, 6 , 2 );
1307+
1308+ // 3. Interleave all 4 elements by first interleaving
1309+ // even elements and then odd
1310+ // elem0 = [0,0,0,0]
1311+ // elem1 = [1,1,1,1]
1312+ // elem2 = [2,2,2,2]
1313+ // elem3 = [3,3,3,3]
13331314 // 02 = [0,2,0,2]
13341315 // 13 = [1,3,1,3]
13351316 // 0213 = [0,1,2,3]
@@ -1540,33 +1521,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
15401521 return failure ();
15411522
15421523 // Perform the rewrite.
1524+ Location loc = conversionOp.getLoc ();
1525+ const auto &extFn = isSigned ? extractNBitsFromVectorSigned
1526+ : extractNBitsFromVectorUnsinged;
15431527 Value subByteExt;
1544- if (isSigned) {
1545- switch (srcVecType.getElementType ().getIntOrFloatBitWidth ()) {
1546- case 2 :
1547- subByteExt =
1548- rewriteI2ToI8SignedExt (rewriter, conversionOp.getLoc (), srcValue);
1549- break ;
1550- case 4 :
1551- subByteExt =
1552- rewriteI4ToI8SignedExt (rewriter, conversionOp.getLoc (), srcValue);
1553- break ;
1554- default :
1555- return failure ();
1556- }
1557- } else {
1558- switch (srcVecType.getElementType ().getIntOrFloatBitWidth ()) {
1559- case 2 :
1560- subByteExt =
1561- rewriteI2ToI8UnsignedExt (rewriter, conversionOp.getLoc (), srcValue);
1562- break ;
1563- case 4 :
1564- subByteExt =
1565- rewriteI4ToI8UnsignedExt (rewriter, conversionOp.getLoc (), srcValue);
1566- break ;
1567- default :
1568- return failure ();
1569- }
1528+ switch (srcVecType.getElementType ().getIntOrFloatBitWidth ()) {
1529+ case 2 :
1530+ subByteExt = rewriteI2ToI8Ext (rewriter, loc, srcValue, extFn);
1531+ break ;
1532+ case 4 :
1533+ subByteExt = rewriteI4ToI8Ext (rewriter, loc, srcValue, extFn);
1534+ break ;
1535+ default :
1536+ return failure ();
15701537 }
15711538
15721539 // Finalize the rewrite.
0 commit comments