Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1069,27 +1069,34 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
return commonConversionPrecondition(rewriter, preconditionType, op);
}

/// Verify that source and destination element types meet the precondition for
/// the supported aligned conversion cases. Alignment means that the either the
/// source element type is multiple of the destination element type or the other
/// way around.
/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
/// means that:
/// 1. The `dstType` element type is a multiple of the
/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
/// is not supported). Let this multiple be `N`.
/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
/// not supported).
///
/// NOTE: This method assumes that common conversion preconditions are met.
/// NOTE: This method assumes that common conversion preconditions are met. In
/// particular, the element type of `dstType` is assumed to be a multi-byte
/// type (e.g. i8, i16, i32).
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
VectorType srcType,
VectorType subByteVecType,
VectorType dstType,
Operation *op) {
if (!srcType || !dstType)
if (!subByteVecType || !dstType)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();

// Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");

if ((srcType.getShape().back() % 2) != 0)
const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
return rewriter.notifyMatchFailure(
op, "Not an even number of i4 elements in trailing dim");

Expand Down
Loading