Skip to content

Commit 21ba7ae

Browse files
authored
[mlir][vector][nfc] Update alignedConversionPrecondition (llvm#122136)
Adds some comments and renames variables to clarify the usage.
1 parent 6f9e688 commit 21ba7ae

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,27 +1069,34 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
10691069
return commonConversionPrecondition(rewriter, preconditionType, op);
10701070
}
10711071

1072-
/// Verify that source and destination element types meet the precondition for
1073-
/// the supported aligned conversion cases. Alignment means that the either the
1074-
/// source element type is multiple of the destination element type or the other
1075-
/// way around.
1072+
/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
1073+
/// means that:
1074+
/// 1. The `dstType` element type is a multiple of the
1075+
/// `subByteVecType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1076+
/// is not supported). Let this multiple be `N`.
1077+
/// 2. The number of the (trailing) elements in `subByteVecType` is a
1078+
/// multiple of `N` from 1. (e.g., when targeting i8, 2xi4 is OK, but 3xi4 is
1079+
/// not supported).
10761080
///
1077-
/// NOTE: This method assumes that common conversion preconditions are met.
1081+
/// NOTE: This method assumes that common conversion preconditions are met. In
1082+
/// particular, the element type of `dstType` is assumed to be a multi-byte
1083+
/// type (e.g. i8, i16, i32).
10781084
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1079-
VectorType srcType,
1085+
VectorType subByteVecType,
10801086
VectorType dstType,
10811087
Operation *op) {
1082-
if (!srcType || !dstType)
1088+
if (!subByteVecType || !dstType)
10831089
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1084-
unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
1090+
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
10851091
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
10861092

10871093
// Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
10881094
if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
10891095
(dstElemBitwidth % srcElemBitwidth) != 0)
10901096
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
10911097

1092-
if ((srcType.getShape().back() % 2) != 0)
1098+
const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
1099+
if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
10931100
return rewriter.notifyMatchFailure(
10941101
op, "Not an even number of i4 elements in trailing dim");
10951102

0 commit comments

Comments
 (0)