Skip to content

Commit cfe31bb

Browse files
committed
refactoring
1 parent 303416a commit cfe31bb

File tree

1 file changed

+78
-111
lines changed

1 file changed

+78
-111
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 78 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,10 +1090,14 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
10901090
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
10911091
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
10921092

1093-
// Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
1094-
if ((srcElemBitwidth != 4 && srcElemBitwidth != 2) || dstElemBitwidth < 8 ||
1095-
(dstElemBitwidth % srcElemBitwidth) != 0)
1096-
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1093+
if (dstElemBitwidth < 8)
1094+
return rewriter.notifyMatchFailure(
1095+
op, "the bitwidth of dstType must be greater than or equal to 8");
1096+
if (dstElemBitwidth % srcElemBitwidth != 0)
1097+
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
1098+
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
1099+
return rewriter.notifyMatchFailure(
1100+
op, "only src bitwidth of 2 or 4 is supported at this moment");
10971101

10981102
const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
10991103
if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
@@ -1179,22 +1183,35 @@ Value BitCastRewriter::genericRewriteStep(
11791183
return runningResult;
11801184
}
11811185

1182-
Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1183-
Value srcValue) {
1184-
VectorType srcVecType = cast<VectorType>(srcValue.getType());
1186+
/// takes a aligned subByte vector as Input and bitcasts it to a vector of i8.
1187+
///
1188+
/// Example:
1189+
/// vector<16x16xi2> -> vector<16x2xi8>
1190+
/// vector<16x16xi4> -> vector<16x4xi8>
1191+
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1192+
Value srcValue) {
1193+
auto srcVecType = cast<VectorType>(srcValue.getType());
11851194
int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1186-
assert(srcBitwidth % 8 != 0 && "Invalid source bitwidth");
1195+
assert(8 % srcBitwidth == 0 && "Invalid source bitwidth");
11871196
int64_t bitwidthFactor = 8 / srcBitwidth;
1188-
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
1189-
i8VecShape.back() = i8VecShape.back() / bitwidthFactor;
1190-
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
1197+
SmallVector<int64_t> vecShape(srcVecType.getShape());
1198+
// adjust last dimension of the vector so the total size remains the same.
1199+
vecShape.back() = vecShape.back() / bitwidthFactor;
1200+
auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
11911201
return rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
11921202
}
11931203

11941204
/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
11951205
/// starting at the specified bit index.
1196-
Value extractNBitsFromVectorSigned(PatternRewriter &rewriter, Location loc,
1197-
Value src, int bitIdx, int numBits) {
1206+
///
1207+
/// Example:
1208+
/// extract numBits=2 starting at bitIdx=2
1209+
/// src = [0101|11|10]
1210+
/// shl = src << 4 -> [11100000]
1211+
/// result = shl >> 6 -> [11111111]
1212+
static Value extractNBitsFromVectorSigned(PatternRewriter &rewriter,
1213+
Location loc, Value src, int bitIdx,
1214+
int numBits) {
11981215
assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
11991216
"Invalid bitIdx range");
12001217
auto srcType = cast<VectorType>(src.getType());
@@ -1213,61 +1230,18 @@ Value extractNBitsFromVectorSigned(PatternRewriter &rewriter, Location loc,
12131230
return shr;
12141231
}
12151232

1216-
/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1217-
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1218-
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
1219-
Value srcValue) {
1220-
VectorType srcVecType = cast<VectorType>(srcValue.getType());
1221-
assert(srcVecType.getElementType().isSignlessInteger(4) &&
1222-
"Expected i4 type");
1223-
1224-
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1225-
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1226-
1227-
// 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
1228-
// byte are place in one vector and the high i4 elements in another vector.
1229-
Value low = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 4);
1230-
Value high = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 4);
1231-
1232-
// 3. Interleave low and high i8 elements.
1233-
return rewriter.create<vector::InterleaveOp>(loc, low, high);
1234-
}
1235-
1236-
/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
1237-
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1238-
static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
1239-
Value srcValue) {
1240-
VectorType srcVecType = cast<VectorType>(srcValue.getType());
1241-
assert(srcVecType.getElementType().isSignlessInteger(2) &&
1242-
"Expected i2 type");
1243-
1244-
// 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1245-
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1246-
1247-
// 2. Extract each i2 element using shifts
1248-
// Element 0 (bits 0-1)
1249-
Value elem0 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 2);
1250-
// Element 1 (bits 2-3)
1251-
Value elem1 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 2, 2);
1252-
// Element 2 (bits 4-5)
1253-
Value elem2 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 2);
1254-
// Element 3 (bits 6-7)
1255-
Value elem3 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 6, 2);
1256-
1257-
// 3. Interleave all 4 elements by first interleaving even elements and then
1258-
// odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
1259-
// 02 = [0,2,0,2]
1260-
// 13 = [1,3,1,3]
1261-
// 0213 = [0,1,2,3]
1262-
Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
1263-
Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
1264-
return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
1265-
}
1266-
12671233
/// Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
12681234
/// starting at the specified bit index.
1269-
Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
1270-
Value src, int bitIdx, int numBits) {
1235+
///
1236+
/// Example:
1237+
/// extract numBits=2 starting at bitIdx=2
1238+
/// src = [0101|10|10]
1239+
/// mask = [00000011]
1240+
/// shr = src >> 6 = [00010110]
1241+
/// result = shr & mask = [00000010]
1242+
static Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter,
1243+
Location loc, Value src, int bitIdx,
1244+
int numBits) {
12711245
assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
12721246
"Invalid bitIdx range");
12731247
auto srcType = cast<VectorType>(src.getType());
@@ -1287,49 +1261,56 @@ Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
12871261
return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
12881262
}
12891263

1290-
/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1264+
using ExtractNBitsFn =
1265+
std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1266+
1267+
/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
12911268
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1292-
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
1293-
Value srcValue) {
1294-
VectorType srcVecType = cast<VectorType>(srcValue.getType());
1269+
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1270+
Value srcValue, const ExtractNBitsFn &extFn) {
1271+
auto srcVecType = cast<VectorType>(srcValue.getType());
12951272
assert(srcVecType.getElementType().isSignlessInteger(4) &&
12961273
"Expected i4 type");
12971274

12981275
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
12991276
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
13001277

1301-
// 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
1302-
// byte are placed in one vector and the high i4 elements in another vector.
1303-
Value low = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 4);
1304-
Value high = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 4);
1278+
// 2. Extend i4 elements to i8 elements. Low i4 elemens of each
1279+
// byte are place in one vector and the high i4 elements in another vector.
1280+
Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1281+
Value high = extFn(rewriter, loc, i8Vector, 4, 4);
13051282

13061283
// 3. Interleave low and high i8 elements.
13071284
return rewriter.create<vector::InterleaveOp>(loc, low, high);
13081285
}
13091286

1310-
/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
1287+
/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
13111288
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1312-
static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
1313-
Value srcValue) {
1289+
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1290+
Value srcValue, const ExtractNBitsFn &extFn) {
13141291
VectorType srcVecType = cast<VectorType>(srcValue.getType());
13151292
assert(srcVecType.getElementType().isSignlessInteger(2) &&
13161293
"Expected i2 type");
13171294

13181295
// 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
13191296
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
13201297

1321-
// 2. Extract each i2 element using shifts and masks
1298+
// 2. Extract each i2 element
13221299
// Element 0 (bits 0-1)
1323-
Value elem0 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 2);
1300+
Value elem0 = extFn(rewriter, loc, i8Vector, 0, 2);
13241301
// Element 1 (bits 2-3)
1325-
Value elem1 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 2, 2);
1302+
Value elem1 = extFn(rewriter, loc, i8Vector, 2, 2);
13261303
// Element 2 (bits 4-5)
1327-
Value elem2 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 2);
1304+
Value elem2 = extFn(rewriter, loc, i8Vector, 4, 2);
13281305
// Element 3 (bits 6-7)
1329-
Value elem3 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 6, 2);
1330-
1331-
// 3. Interleave all 4 elements by first interleaving even elements and then
1332-
// odd elem0 = [0,0,0,0] elem1 = [1,1,1,1] elem2 = [2,2,2,2] elem3 = [3,3,3,3]
1306+
Value elem3 = extFn(rewriter, loc, i8Vector, 6, 2);
1307+
1308+
// 3. Interleave all 4 elements by first interleaving
1309+
// even elements and then odd
1310+
// elem0 = [0,0,0,0]
1311+
// elem1 = [1,1,1,1]
1312+
// elem2 = [2,2,2,2]
1313+
// elem3 = [3,3,3,3]
13331314
// 02 = [0,2,0,2]
13341315
// 13 = [1,3,1,3]
13351316
// 0213 = [0,1,2,3]
@@ -1540,33 +1521,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
15401521
return failure();
15411522

15421523
// Perform the rewrite.
1524+
Location loc = conversionOp.getLoc();
1525+
const auto &extFn = isSigned ? extractNBitsFromVectorSigned
1526+
: extractNBitsFromVectorUnsinged;
15431527
Value subByteExt;
1544-
if (isSigned) {
1545-
switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
1546-
case 2:
1547-
subByteExt =
1548-
rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1549-
break;
1550-
case 4:
1551-
subByteExt =
1552-
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1553-
break;
1554-
default:
1555-
return failure();
1556-
}
1557-
} else {
1558-
switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
1559-
case 2:
1560-
subByteExt =
1561-
rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1562-
break;
1563-
case 4:
1564-
subByteExt =
1565-
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1566-
break;
1567-
default:
1568-
return failure();
1569-
}
1528+
switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
1529+
case 2:
1530+
subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
1531+
break;
1532+
case 4:
1533+
subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
1534+
break;
1535+
default:
1536+
return failure();
15701537
}
15711538

15721539
// Finalize the rewrite.

0 commit comments

Comments
 (0)