Skip to content

Commit 303416a

Browse files
committed
refactoring:
- extracts repeated code into functions - reorder tests - improve naming
1 parent 99b6785 commit 303416a

File tree

2 files changed

+178
-186
lines changed

2 files changed

+178
-186
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 116 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,112 +1179,83 @@ Value BitCastRewriter::genericRewriteStep(
11791179
return runningResult;
11801180
}
11811181

1182-
/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1183-
/// bitwise ops that take advantage of high-level information to avoid leaving
1184-
/// LLVM to scramble with peephole optimizations.
1185-
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
1186-
Value srcValue) {
1182+
/// Bitcasts a vector of sub-byte integer elements to a vector of i8 by
/// dividing the innermost dimension by the number of source elements that fit
/// in one byte.
///
/// Example:
///   vector<16x16xi2> -> vector<16x4xi8>
///   vector<16x16xi4> -> vector<16x8xi8>
///
/// \param rewriter  Rewriter used to create the bitcast op.
/// \param loc       Location attached to the created op.
/// \param srcValue  Vector value whose element bitwidth is a divisor of 8 that
///                  is smaller than 8.
/// \returns The result of the created `vector::BitCastOp`, of element type i8.
Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
                               Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
  // The bitwidth has to divide 8 exactly. Checking `srcBitwidth % 8 != 0`
  // instead would also admit widths such as 3 (the factor below truncates and
  // the resulting shape is wrong) or 12 (the factor becomes 0 and the shape
  // computation divides by zero).
  assert(srcBitwidth > 0 && srcBitwidth < 8 && 8 % srcBitwidth == 0 &&
         "Invalid source bitwidth");
  // Number of source elements packed into a single i8 element.
  int64_t bitwidthFactor = 8 / srcBitwidth;
  SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
  // Only the innermost dimension shrinks; outer dimensions are unchanged.
  i8VecShape.back() = i8VecShape.back() / bitwidthFactor;
  auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
  return rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
}
11971193

1198-
// 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
1199-
// byte are place in one vector and the high i4 elements in another vector.
1200-
constexpr int8_t bitsToShift = 4;
1201-
auto shiftValues = rewriter.create<arith::ConstantOp>(
1202-
loc, DenseElementsAttr::get(i8VecType, bitsToShift));
1203-
Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
1204-
Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
1205-
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
1194+
/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
1195+
/// starting at the specified bit index.
1196+
Value extractNBitsFromVectorSigned(PatternRewriter &rewriter, Location loc,
1197+
Value src, int bitIdx, int numBits) {
1198+
assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1199+
"Invalid bitIdx range");
1200+
auto srcType = cast<VectorType>(src.getType());
1201+
Value shl = src;
1202+
int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1203+
if (bitsToShiftLeft != 0) {
1204+
Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
1205+
loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1206+
shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
1207+
}
12061208

1207-
// 3. Interleave low and high i8 elements.
1208-
return rewriter.create<vector::InterleaveOp>(loc, low, high);
1209+
int8_t bitsToShiftRight = 8 - numBits;
1210+
Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1211+
loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1212+
Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
1213+
return shr;
12091214
}
12101215

1211-
/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1212-
/// bitwise ops that take advantage of high-level information to avoid leaving
1213-
/// LLVM to scramble with peephole optimizations.
1214-
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
1215-
Value srcValue) {
1216+
/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1217+
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1218+
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
1219+
Value srcValue) {
12161220
VectorType srcVecType = cast<VectorType>(srcValue.getType());
12171221
assert(srcVecType.getElementType().isSignlessInteger(4) &&
12181222
"Expected i4 type");
12191223

12201224
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1221-
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
1222-
constexpr int64_t i4Toi8BitwidthFactor = 2;
1223-
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
1224-
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
1225-
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
1225+
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
12261226

1227-
// 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
1228-
// byte are placed in one vector and the high i4 elements in another vector.
1229-
constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
1230-
auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
1231-
loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
1232-
Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
1233-
lowBitsMaskValues);
1234-
constexpr int8_t highBitsToShift = 4;
1235-
auto highShiftValues = rewriter.create<arith::ConstantOp>(
1236-
loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
1237-
Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
1227+
// 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
1228+
// byte are placed in one vector and the high i4 elements in another vector.
1229+
Value low = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 4);
1230+
Value high = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 4);
12381231

12391232
// 3. Interleave low and high i8 elements.
12401233
return rewriter.create<vector::InterleaveOp>(loc, low, high);
12411234
}
12421235

12431236
/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
1244-
/// bitwise ops that take advantage of high-level information to avoid leaving
1245-
/// LLVM to scramble with peephole optimizations.
1237+
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
12461238
static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
12471239
Value srcValue) {
12481240
VectorType srcVecType = cast<VectorType>(srcValue.getType());
12491241
assert(srcVecType.getElementType().isSignlessInteger(2) &&
12501242
"Expected i2 type");
12511243

12521244
// 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
1253-
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
1254-
constexpr int64_t i2Toi8BitwidthFactor = 4;
1255-
i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
1256-
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
1257-
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
1245+
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
12581246

1247+
// 2. Extract each i2 element using shifts
12591248
// Element 0 (bits 0-1)
1260-
constexpr int8_t shiftConst6 = 6;
1261-
auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
1262-
auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
1263-
Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
1264-
Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
1265-
1249+
Value elem0 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 0, 2);
12661250
// Element 1 (bits 2-3)
1267-
constexpr int8_t shiftConst4 = 4;
1268-
auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
1269-
auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
1270-
Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
1271-
Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
1272-
1251+
Value elem1 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 2, 2);
12731252
// Element 2 (bits 4-5)
1274-
constexpr int8_t shiftConst2 = 2;
1275-
auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
1276-
auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
1277-
Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
1278-
Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
1279-
1253+
Value elem2 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 4, 2);
12801254
// Element 3 (bits 6-7)
1281-
Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
1255+
Value elem3 = extractNBitsFromVectorSigned(rewriter, loc, i8Vector, 6, 2);
12821256

1283-
// interleave all 4 elements by first interleaving even elements and then odd
1284-
// elem0 = [0,0,0,0]
1285-
// elem1 = [1,1,1,1]
1286-
// elem2 = [2,2,2,2]
1287-
// elem3 = [3,3,3,3]
1257+
// 3. Interleave all 4 elements by first interleaving even elements and then
1258+
// odd.
// elem0 = [0,0,0,0]
// elem1 = [1,1,1,1]
// elem2 = [2,2,2,2]
// elem3 = [3,3,3,3]
12881259
// 02 = [0,2,0,2]
12891260
// 13 = [1,3,1,3]
12901261
// 0213 = [0,1,2,3]
@@ -1293,56 +1264,72 @@ static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
12931264
return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
12941265
}
12951266

1267+
/// Extracts an unsigned N-bit sequence from each element of an 8-bit vector,
1268+
/// starting at the specified bit index.
1269+
Value extractNBitsFromVectorUnsinged(PatternRewriter &rewriter, Location loc,
1270+
Value src, int bitIdx, int numBits) {
1271+
assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1272+
"Invalid bitIdx range");
1273+
auto srcType = cast<VectorType>(src.getType());
1274+
int8_t bitsToShiftRight = bitIdx;
1275+
Value shr = src;
1276+
if (bitsToShiftRight != 0) {
1277+
Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1278+
loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1279+
shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
1280+
}
1281+
if (bitIdx + numBits == 8) {
1282+
return shr;
1283+
}
1284+
uint8_t lowBitsMask = (1 << numBits) - 1;
1285+
Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
1286+
loc, DenseElementsAttr::get(srcType, lowBitsMask));
1287+
return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
1288+
}
1289+
1290+
/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1291+
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1292+
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
1293+
Value srcValue) {
1294+
VectorType srcVecType = cast<VectorType>(srcValue.getType());
1295+
assert(srcVecType.getElementType().isSignlessInteger(4) &&
1296+
"Expected i4 type");
1297+
1298+
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1299+
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1300+
1301+
// 2. Extend the i4 elements using shifts & masking. Low i4 elements of each
1302+
// byte are placed in one vector and the high i4 elements in another vector.
1303+
Value low = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 4);
1304+
Value high = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 4);
1305+
1306+
// 3. Interleave low and high i8 elements.
1307+
return rewriter.create<vector::InterleaveOp>(loc, low, high);
1308+
}
1309+
12961310
/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
1297-
/// bitwise ops that take advantage of high-level information to avoid leaving
1298-
/// LLVM to scramble with peephole optimizations.
1311+
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
12991312
static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
13001313
Value srcValue) {
13011314
VectorType srcVecType = cast<VectorType>(srcValue.getType());
13021315
assert(srcVecType.getElementType().isSignlessInteger(2) &&
13031316
"Expected i2 type");
13041317

13051318
// 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
1306-
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
1307-
constexpr int64_t i2Toi8BitwidthFactor = 4;
1308-
i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
1309-
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
1310-
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
1319+
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
13111320

13121321
// 2. Extract each i2 element using shifts and masks
1313-
constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
1314-
auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
1315-
auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
1316-
13171322
// Element 0 (bits 0-1)
1318-
Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
1319-
1323+
Value elem0 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 0, 2);
13201324
// Element 1 (bits 2-3)
1321-
constexpr int8_t shift1 = 2;
1322-
auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
1323-
auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
1324-
Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
1325-
Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
1326-
1325+
Value elem1 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 2, 2);
13271326
// Element 2 (bits 4-5)
1328-
constexpr int8_t shift2 = 4;
1329-
auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shift2);
1330-
auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
1331-
Value shifted2 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues2);
1332-
Value elem2 = rewriter.create<arith::AndIOp>(loc, shifted2, maskValues);
1333-
1327+
Value elem2 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 4, 2);
13341328
// Element 3 (bits 6-7)
1335-
constexpr int8_t shift3 = 6;
1336-
auto shiftAttr3 = DenseElementsAttr::get(i8VecType, shift3);
1337-
auto shiftValues3 = rewriter.create<arith::ConstantOp>(loc, shiftAttr3);
1338-
Value shifted3 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues3);
1339-
Value elem3 = rewriter.create<arith::AndIOp>(loc, shifted3, maskValues);
1340-
1341-
// interleave all 4 elements by first interleaving even elements and then odd
1342-
// elem0 = [0,0,0,0]
1343-
// elem1 = [1,1,1,1]
1344-
// elem2 = [2,2,2,2]
1345-
// elem3 = [3,3,3,3]
1329+
Value elem3 = extractNBitsFromVectorUnsinged(rewriter, loc, i8Vector, 6, 2);
1330+
1331+
// 3. Interleave all 4 elements by first interleaving even elements and then
1332+
// odd.
// elem0 = [0,0,0,0]
// elem1 = [1,1,1,1]
// elem2 = [2,2,2,2]
// elem3 = [3,3,3,3]
13461333
// 02 = [0,2,0,2]
13471334
// 13 = [1,3,1,3]
13481335
// 0213 = [0,1,2,3]
@@ -1352,8 +1339,7 @@ static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
13521339
}
13531340

13541341
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1355-
/// ops that take advantage of high-level information to avoid leaving LLVM to
1356-
/// scramble with peephole optimizations.
1342+
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
13571343
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
13581344
Value srcValue) {
13591345
VectorType srcVecType = cast<VectorType>(srcValue.getType());
@@ -1556,20 +1542,30 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
15561542
// Perform the rewrite.
15571543
Value subByteExt;
15581544
if (isSigned) {
1559-
if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2)
1545+
switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
1546+
case 2:
15601547
subByteExt =
15611548
rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1562-
else {
1549+
break;
1550+
case 4:
15631551
subByteExt =
15641552
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1553+
break;
1554+
default:
1555+
return failure();
15651556
}
15661557
} else {
1567-
if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
1558+
switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
1559+
case 2:
15681560
subByteExt =
15691561
rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1570-
} else {
1562+
break;
1563+
case 4:
15711564
subByteExt =
15721565
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1566+
break;
1567+
default:
1568+
return failure();
15731569
}
15741570
}
15751571

@@ -1611,16 +1607,16 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
16111607
if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
16121608
return failure();
16131609

1610+
// TODO: Add support for truncating to i2.
1611+
if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
1612+
return failure();
1613+
16141614
// Check general alignment preconditions. We invert the src/dst type order
16151615
// to reuse the existing precondition logic.
16161616
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
16171617
truncOp)))
16181618
return failure();
16191619

1620-
// not supported currently.
1621-
if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
1622-
return failure();
1623-
16241620
// Create a new iX -> i8 truncation op.
16251621
Location loc = truncOp.getLoc();
16261622
auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());

0 commit comments

Comments
 (0)