@@ -1183,32 +1183,41 @@ Value BitCastRewriter::genericRewriteStep(
11831183 return runningResult;
11841184}
11851185
1186- // / takes a aligned subByte vector as Input and bitcasts it to a vector of i8.
1186+ // / Bitcasts the aligned `subByteVec` vector to a vector of i8.
1187+ // / Where aligned means it satisfies the alignedConversionPreconditions.
11871188// /
11881189// / Example:
11891190// / vector<16x16xi2> -> vector<16x2xi8>
11901191// / vector<16x16xi4> -> vector<16x4xi8>
11911192static Value bitcastSubByteVectorToI8 (PatternRewriter &rewriter, Location loc,
1192- Value srcValue ) {
1193- auto srcVecType = cast<VectorType>(srcValue .getType ());
1193+ Value subByteVec ) {
1194+ auto srcVecType = cast<VectorType>(subByteVec .getType ());
11941195 int64_t srcBitwidth = srcVecType.getElementType ().getIntOrFloatBitWidth ();
11951196 assert (8 % srcBitwidth == 0 && " Invalid source bitwidth" );
11961197 int64_t bitwidthFactor = 8 / srcBitwidth;
11971198 SmallVector<int64_t > vecShape (srcVecType.getShape ());
1198- // adjust last dimension of the vector so the total size remains the same.
1199+ // Adjust last dimension of the vector, so the total size remains the same.
11991200 vecShape.back () = vecShape.back () / bitwidthFactor;
12001201 auto i8VecType = VectorType::get (vecShape, rewriter.getI8Type ());
1201- return rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue );
1202+ return rewriter.create <vector::BitCastOp>(loc, i8VecType, subByteVec );
12021203}
12031204
12041205// / Extracts a signed N-bit sequence from each element of an 8-bit vector,
12051206// / starting at the specified bit index.
1207+ // / The `bitIdx` starts at 0 from the LSB and moves to the left.
12061208// /
1207- // / Example:
1209+ // / Example for a single element :
12081210// / extract numBits=2 starting at bitIdx=2
1209- // / src = [0101|11|10]
1210- // / shl = src << 4 -> [11100000]
1211- // / result = shl >> 6 -> [11111111]
1211+ // / src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1212+ // / indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1213+ // / target = [. . . . ^ ^ . .]
1214+ // /
1215+ // / The target sequence is [11](decimal=-1) as signed 2-bit integer.
1216+ // / So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1217+ // /
1218+ // / src = [01 01 11 10]
1219+ // / shl = arith.shl(src, 4) -> [11 10 00 00]
1220+ // / result = arith.shrsi(shl, 6) -> [11 11 11 11]
12121221static Value extractNBitsFromVectorSigned (PatternRewriter &rewriter,
12131222 Location loc, Value src, int bitIdx,
12141223 int numBits) {
@@ -1232,13 +1241,21 @@ static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
12321241
12331242// / Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
12341243// / starting at the specified bit index.
1244+ // / The `bitIdx` starts at 0 from the LSB and moves to the left.
12351245// /
1236- // / Example:
1246+ // / Example for a single element :
12371247// / extract numBits=2 starting at bitIdx=2
1238- // / src = [0101|10|10]
1239- // / mask = [00000011]
1240- // / shr = src >> 6 = [00010110]
1241- // / result = shr & mask = [00000010]
1248+ // / src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1249+ // / indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1250+ // / target = [. . . . ^ ^ . .]
1251+ // /
1252+ // / The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1253+ // / So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1254+ // /
1255+ // / src = [01 01 10 10]
1256+ // / mask = [00 00 00 11]
1257+ // / shr = arith.shrui(src, 2) = [00 01 01 10]
1258+ // / result = arith.andi(shr, mask) = [00 00 00 10]
12421259static Value extractNBitsFromVectorUnsinged (PatternRewriter &rewriter,
12431260 Location loc, Value src, int bitIdx,
12441261 int numBits) {
0 commit comments