Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 170 additions & 60 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,12 +1090,16 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();

// Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
if (dstElemBitwidth < 8)
return rewriter.notifyMatchFailure(
op, "the bitwidth of dstType must be greater than or equal to 8");
if (dstElemBitwidth % srcElemBitwidth != 0)
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
return rewriter.notifyMatchFailure(
op, "only src bitwidth of 2 or 4 is supported at this moment");

const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
const int numSrcElemsPerDestElem = 8 / srcElemBitwidth;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why hard-code this to 8?

Also, comment for L1105 (sadly GitHub doesn't allow comments for things outside the diff) - please update the error msg.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this check since the dst in this case will always be an i8 vector.

For example:

%0 = arith.extui %a : vector<4xi2> to vector<4xi32>

This should be possible because the only requirement is that the 4 i2 values can be packed into a 1xi8 vector (otherwise the bitcast would be invalid).

I also added a test checking that something like this fails:

%0 = arith.extui %a : vector<2xi2> to vector<2xi32>

if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
return rewriter.notifyMatchFailure(
op, "Not an even number of i4 elements in trailing dim");
Expand Down Expand Up @@ -1179,70 +1183,166 @@ Value BitCastRewriter::genericRewriteStep(
return runningResult;
}

/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(4) &&
"Expected i4 type");
/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
/// Here, "aligned" means that it satisfies `alignedConversionPrecondition`.
///
/// Example:
/// vector<16x16xi2> -> vector<16x4xi8>
/// vector<16x16xi4> -> vector<16x8xi8>
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
Value subByteVec) {
auto srcVecType = cast<VectorType>(subByteVec.getType());
int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
assert(8 % srcBitwidth == 0 &&
"Unsupported sub-byte type (not a divisor of i8)");
int64_t bitwidthFactor = 8 / srcBitwidth;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To keep the naming consistent within the file, could this be numSrcElemsPerByte?

Similar variable elsewhere:

const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. There are some other places that use the same name, but I'm not sure whether they refer to the same thing.

SmallVector<int64_t> vecShape(srcVecType.getShape());
// Adjust last dimension of the vector, so the total size remains the same.
vecShape.back() = vecShape.back() / bitwidthFactor;
auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
}

// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i4Toi8BitwidthFactor = 2;
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
/// Extracts a signed N-bit sequence from each element of an 8-bit vector,
/// starting at the specified bit index.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am a bit confused about the naming and this description.

IIUC, this method will:

  • for every byte b in src (which is a vector of bytes),
  • extracts numBits starting at bitIdx (let's call it inputVal), and
  • returns a byte matching the value encoded in inputVal.

So this method is more like extractNBitsAndReturnAsByte?

Copy link
Contributor Author

@ziereis ziereis Jan 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • for every byte b in src (which is a vector of bytes),
  • extracts numBits starting at bitIdx (let's call it inputVal), and
  • returns a byte matching the value encoded in inputVal.

It extracts numBits from every byte of src, starting at bitIdx, and returns a vector of bytes; the result type is always the same as the src type.
For example, if numBits is 4, it treats inputVal as an i4 and (sign-)extends it to an i8 value.

I'm not sure about the name either; maybe extractNBitsAnd(Sign)ExtToI8?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extractNbitsPerByteAndExtendToI8?

Am I correct that this method assumes that the src and dst element type is i8?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

/// The `bitIdx` starts at 0 from the LSB and moves to the left.
///
/// Example for a single element:
/// Extract numBits=2 starting at bitIdx=2
/// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
/// target = [. . . . ^ ^ . .]
///
/// The target sequence is [11](decimal=-1) as signed 2-bit integer.
/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
///
/// src = [01 01 11 10]
/// shl = arith.shl(src, 4) -> [11 10 00 00]
/// result = arith.shrsi(shl, 6) -> [11 11 11 11]
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
Location loc, Value src,
int bitIdx, int numBits) {
assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
"Invalid bitIdx range");
auto srcType = cast<VectorType>(src.getType());
Value shl = src;
int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
if (bitsToShiftLeft != 0) {
Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
}

// 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
// byte are place in one vector and the high i4 elements in another vector.
constexpr int8_t bitsToShift = 4;
auto shiftValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(i8VecType, bitsToShift));
Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
int8_t bitsToShiftRight = 8 - numBits;
Value shiftRightValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
return shr;
}

// 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
/// Zero-extends an N-bit field taken from every element of an 8-bit vector.
///
/// \param src     vector value whose elements are processed one by one.
/// \param bitIdx  position of the field's lowest bit, counted from the LSB
///                (bit 0) towards the MSB.
/// \param numBits width of the field in bits.
/// \returns a value produced by shifting/masking `src`, holding the field of
///          each element as an unsigned 8-bit integer.
///
/// Worked example for one element, with numBits=2 and bitIdx=2:
///   bit index : 7 6 5 4 3 2 1 0
///   src       : 0 1 0 1 1 0 1 0
///   field     : . . . . ^ ^ . .   -> [10], i.e. 2 as an unsigned 2-bit value
///
///   shr    = arith.shrui(src, 2)    -> [00 01 01 10]
///   mask   =                           [00 00 00 11]
///   result = arith.andi(shr, mask)  -> [00 00 00 10]  (decimal 2)
///
/// NOTE: extractNBitsPerByteAndSignExtendToI8 uses a shift-left followed by a
/// shift-right. The same approach (arith::ShLIOp + arith::ShRUIOp) would work
/// here too, but arith::ShRUIOp + arith::AndIOp needs no shift at all when
/// `bitIdx` is 0.
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
                                              Location loc, Value src,
                                              int bitIdx, int numBits) {
  assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
         "Invalid bitIdx range");
  auto byteVecType = cast<VectorType>(src.getType());

  // Bring the field down to bit 0. A field that already starts at bit 0 needs
  // no shift, so no op is emitted for it.
  Value shifted = src;
  int8_t rightShiftAmount = static_cast<int8_t>(bitIdx);
  if (rightShiftAmount != 0) {
    Value shiftSplat = rewriter.create<arith::ConstantOp>(
        loc, DenseElementsAttr::get(byteVecType, rightShiftAmount));
    shifted = rewriter.create<arith::ShRUIOp>(loc, src, shiftSplat);
  }

  // When the field ends at bit 7 the logical shift has already cleared every
  // bit above it, so masking would be redundant.
  const bool fieldEndsAtTopBit = (bitIdx + numBits == 8);
  if (fieldEndsAtTopBit)
    return shifted;

  // Keep only the low `numBits` bits.
  uint8_t fieldMask = static_cast<uint8_t>((1 << numBits) - 1);
  Value maskSplat = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(byteVecType, fieldMask));
  return rewriter.create<arith::AndIOp>(loc, shifted, maskSplat);
}

/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
/// Callable used by the sub-byte extension rewrites to pull one bit-field out
/// of every element of a vector. The arguments are, in order: the rewriter,
/// the location for the created ops, the source vector, the index of the
/// field's lowest bit, and the field width in bits. Both
/// extractNBitsPerByteAndSignExtendToI8 and extractNBitsPerByteAndExtendToI8
/// have this signature.
using ExtractNBitsFn =
std::function<Value(PatternRewriter &, Location, Value, int, int)>;

/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
Value srcValue, const ExtractNBitsFn &extFn) {
auto srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(4) &&
"Expected i4 type");

// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i4Toi8BitwidthFactor = 2;
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

// 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
// byte are placed in one vector and the high i4 elements in another vector.
constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
lowBitsMaskValues);
constexpr int8_t highBitsToShift = 4;
auto highShiftValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

// 2. Extend i4 elements to i8 elements. Low i4 elements of each
// byte are placed in one vector and the high i4 elements in another vector.
Value low = extFn(rewriter, loc, i8Vector, 0, 4);
Value high = extFn(rewriter, loc, i8Vector, 4, 4);

// 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}

/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
///
/// \param srcValue vector value with signless i2 elements (asserted).
/// \param extFn    callback that extracts one bit-field from every byte of
///                 the bitcast vector; it is invoked once per i2 position
///                 and decides how the field is extended.
/// \returns the result of interleaving the four per-position vectors so that
///          the elements appear in their original order.
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
                              Value srcValue, const ExtractNBitsFn &extFn) {
  // Only the assert reads this variable; `[[maybe_unused]]` keeps builds with
  // assertions disabled free of unused-variable warnings.
  [[maybe_unused]] VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(2) &&
         "Expected i2 type");

  // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8> (four i2 elements
  //    are packed in every byte).
  Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

  // 2. Extract each i2 element.
  // Position 0 (bits 0-1)
  Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
  // Position 1 (bits 2-3)
  Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
  // Position 2 (bits 4-5)
  Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
  // Position 3 (bits 6-7)
  Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

  // 3. Interleave all 4 elements by first interleaving the even positions
  //    with each other, then the odd positions, and finally the two results:
  //      vec0 = [0,0,0,0],...
  //      vec1 = [1,1,1,1],...
  //      vec2 = [2,2,2,2],...
  //      vec3 = [3,3,3,3],...
  //      02   = [0,2,0,2,...],...
  //      13   = [1,3,1,3,...],...
  //      0213 = [0,1,2,3,...],...
  Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
  Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
  return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
}

/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
Expand Down Expand Up @@ -1443,13 +1543,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
return failure();

// Perform the rewrite.
Location loc = conversionOp.getLoc();
const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
: extractNBitsPerByteAndExtendToI8;
Value subByteExt;
if (isSigned) {
subByteExt =
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
} else {
subByteExt =
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
case 2:
subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
break;
case 4:
subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
break;
default:
return failure();
}

// Finalize the rewrite.
Expand Down Expand Up @@ -1490,6 +1596,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
return failure();

// TODO: Add support for truncating to i2.
if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
return failure();

// Check general alignment preconditions. We invert the src/dst type order
// to reuse the existing precondition logic.
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
Expand Down
Loading
Loading