From ffd50ccf09eda8a196def71fb45fb5599faadd9e Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Wed, 8 Jan 2025 16:28:19 +0000 Subject: [PATCH] [mlir][vector][nfc] Update `alignedConversionPrecondition` Adds some comments and renames variables to clarify the usage. --- .../Transforms/VectorEmulateNarrowType.cpp | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 181c394edc1d2..d04f302200519 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -1069,19 +1069,25 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter, return commonConversionPrecondition(rewriter, preconditionType, op); } -/// Verify that source and destination element types meet the precondition for -/// the supported aligned conversion cases. Alignment means that the either the -/// source element type is multiple of the destination element type or the other -/// way around. +/// Verify that `subByteVecType` and `dstType` are aligned. Alignment +/// means that: +/// 1. The `dstType` element type is a multiple of the +/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8 +/// is not supported). Let this multiple be `N`. +/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a +/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is +/// not supported). /// -/// NOTE: This method assumes that common conversion preconditions are met. +/// NOTE: This method assumes that common conversion preconditions are met. In +/// particular, the element type of `dstType` is assumed to be a multi-byte +/// type (e.g. i8, i16, i32). static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, - VectorType srcType, + VectorType subByteVecType, VectorType dstType, Operation *op) { - if (!srcType || !dstType) + if (!subByteVecType || !dstType) return rewriter.notifyMatchFailure(op, "Not a supported aligned case"); - unsigned srcElemBitwidth = srcType.getElementTypeBitWidth(); + unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth(); unsigned dstElemBitwidth = dstType.getElementTypeBitWidth(); // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now. @@ -1089,7 +1095,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, (dstElemBitwidth % srcElemBitwidth) != 0) return rewriter.notifyMatchFailure(op, "Not a supported aligned case"); - if ((srcType.getShape().back() % 2) != 0) + const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth; + if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0) return rewriter.notifyMatchFailure( op, "Not an even number of i4 elements in trailing dim");