diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 51e72753ff162..cf6efaa04ae44 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -519,7 +519,7 @@ struct ConvertVectorStore final : OpConversionPattern { auto origElements = valueToStore.getType().getNumElements(); // Note, per-element-alignment was already verified above. - bool isFullyAligned = origElements % emulatedPerContainerElem == 0; + bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = rewriter.create(loc, op.getBase()); @@ -535,8 +535,8 @@ struct ConvertVectorStore final : OpConversionPattern { getAsOpFoldResult(adaptor.getIndices())); std::optional foldedNumFrontPadElems = - isFullyAligned ? 0 - : getConstantIntValue(linearizedInfo.intraDataOffset); + isDivisibleInSize ? 0 + : getConstantIntValue(linearizedInfo.intraDataOffset); if (!foldedNumFrontPadElems) { return rewriter.notifyMatchFailure( @@ -554,7 +554,7 @@ struct ConvertVectorStore final : OpConversionPattern { // need unaligned emulation because the store address is aligned and the // source is a whole byte. bool emulationRequiresPartialStores = - !isFullyAligned || *foldedNumFrontPadElems != 0; + !isDivisibleInSize || *foldedNumFrontPadElems != 0; if (!emulationRequiresPartialStores) { // Basic case: storing full bytes. auto numElements = origElements / emulatedPerContainerElem; @@ -881,7 +881,7 @@ struct ConvertVectorLoad final : OpConversionPattern { auto origElements = op.getVectorType().getNumElements(); // Note, per-element-alignment was already verified above. 
- bool isFullyAligned = origElements % emulatedPerContainerElem == 0; + bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = rewriter.create(loc, op.getBase()); @@ -897,8 +897,8 @@ struct ConvertVectorLoad final : OpConversionPattern { getAsOpFoldResult(adaptor.getIndices())); std::optional foldedIntraVectorOffset = - isFullyAligned ? 0 - : getConstantIntValue(linearizedInfo.intraDataOffset); + isDivisibleInSize ? 0 + : getConstantIntValue(linearizedInfo.intraDataOffset); // Always load enough elements which can cover the original elements. int64_t maxintraDataOffset = @@ -915,7 +915,7 @@ struct ConvertVectorLoad final : OpConversionPattern { result = dynamicallyExtractSubVector( rewriter, loc, dyn_cast>(result), resultVector, linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { result = staticallyExtractSubvector( rewriter, loc, result, *foldedIntraVectorOffset, origElements); } @@ -1002,7 +1002,7 @@ struct ConvertVectorMaskedLoad final auto origType = op.getVectorType(); auto origElements = origType.getNumElements(); // Note, per-element-alignment was already verified above. - bool isFullyAligned = origElements % emulatedPerContainerElem == 0; + bool isDivisibleInSize = origElements % emulatedPerContainerElem == 0; auto stridedMetadata = rewriter.create(loc, op.getBase()); @@ -1017,8 +1017,8 @@ struct ConvertVectorMaskedLoad final getAsOpFoldResult(adaptor.getIndices())); std::optional foldedIntraVectorOffset = - isFullyAligned ? 0 - : getConstantIntValue(linearizedInfo.intraDataOffset); + isDivisibleInSize ? 
0 + : getConstantIntValue(linearizedInfo.intraDataOffset); int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1); @@ -1042,7 +1042,7 @@ struct ConvertVectorMaskedLoad final passthru = dynamicallyInsertSubVector( rewriter, loc, passthru, emptyVector, linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { passthru = staticallyInsertSubvector(rewriter, loc, passthru, emptyVector, *foldedIntraVectorOffset); } @@ -1070,7 +1070,7 @@ struct ConvertVectorMaskedLoad final mask = dynamicallyInsertSubVector(rewriter, loc, mask, emptyMask, linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { mask = staticallyInsertSubvector(rewriter, loc, op.getMask(), emptyMask, *foldedIntraVectorOffset); } @@ -1081,7 +1081,7 @@ struct ConvertVectorMaskedLoad final result = dynamicallyExtractSubVector( rewriter, loc, result, op.getPassThru(), linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { result = staticallyExtractSubvector( rewriter, loc, result, *foldedIntraVectorOffset, origElements); } @@ -1091,6 +1091,38 @@ struct ConvertVectorMaskedLoad final } }; +/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy` +/// +/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g. +/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy` +/// (a multi-byte scalar, e.g. i16), where N is some integer. 
+/// +/// Put differently, this method checks whether this would be valid: +/// +/// vector.bitcast subByteVecTy into vector +/// +/// EXAMPLES: +/// * vector<4xi4> -> i16 - yes (N = 1) +/// * vector<4xi4> -> i8 - yes (N = 2) +/// * vector<3xi4> -> i8 - no (N would have to be 1.5) +/// * vector<3xi2> -> i16 - no (N would have to be 0.5) +static bool fitsInMultiByteContainerTy(VectorType subByteVecTy, + Type multiByteScalarTy) { + assert((isa(multiByteScalarTy)) && "Not scalar!"); + + int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth(); + int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth(); + + assert(subByteBits < 8 && "Not a sub-byte scalar type!"); + assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!"); + assert(multiByteBits % subByteBits == 0 && "Unaligned element types!"); + + int elemsPerMultiByte = multiByteBits / subByteBits; + + // TODO: This is a bit too restrictive for vectors rank > 1. + return subByteVecTy.getShape().back() % elemsPerMultiByte == 0; +} + //===----------------------------------------------------------------------===// // ConvertVectorTransferRead //===----------------------------------------------------------------------===// @@ -1127,7 +1159,8 @@ struct ConvertVectorTransferRead final auto origElements = op.getVectorType().getNumElements(); // Note, per-element-alignment was already verified above. - bool isFullyAligned = origElements % emulatedPerContainerElem == 0; + bool isDivisibleInSize = + fitsInMultiByteContainerTy(op.getVectorType(), containerElemTy); auto newPadding = rewriter.create(loc, containerElemTy, adaptor.getPadding()); @@ -1146,8 +1179,8 @@ struct ConvertVectorTransferRead final getAsOpFoldResult(adaptor.getIndices())); std::optional foldedIntraVectorOffset = - isFullyAligned ? 0 - : getConstantIntValue(linearizedInfo.intraDataOffset); + isDivisibleInSize ? 
0 + : getConstantIntValue(linearizedInfo.intraDataOffset); int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1); @@ -1171,7 +1204,7 @@ struct ConvertVectorTransferRead final result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros, linearizedInfo.intraDataOffset, origElements); - } else if (!isFullyAligned) { + } else if (!isDivisibleInSize) { result = staticallyExtractSubvector( rewriter, loc, result, *foldedIntraVectorOffset, origElements); } @@ -1428,41 +1461,69 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter, return commonConversionPrecondition(rewriter, preconditionType, op); } -/// Verify that `subByteVecType` and `dstType` are aligned. Alignment -/// means that: -/// 1. The `dstType` element type is a multiple of the -/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8 -/// is not supported). Let this multiple be `N`. -/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a -/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is -/// not supported). +/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned. +/// +/// Alignment means that `subByteVecTy` can be packed into a vector of +/// `containerTy` elements. More specifically: +/// 1. The bit-width of `containerTy` is a multiple of the +/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16` +/// this multiple is 4. +/// 2. The multiple from 1. above divides evenly the number of the (trailing) +/// elements in `subByteVecTy`. +/// +/// EXAMPLE 1: +/// `subByteVecTy = vector<2xi4>`, and +/// `containerTy = i16` +/// +/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_. +/// +/// EXAMPLE 2: +/// `subByteVecTy = vector<3xi4>`, and +/// `containerTy = i16` +/// +/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_. 
+/// +/// EXAMPLE 3: +/// `subByteVecTy = vector<3xi3>`, and +/// `containerTy = i16` +/// +/// 16 _is not_ a multiple of 3, hence the conditions are _not met_. /// /// NOTE: This method assumes that common conversion preconditions are met. In -/// particular, the element type of `dstType` is assumed to be a multi-byte -/// type (e.g. i8, i16, i32). +/// particular, `containerTy` is assumed to be a +/// multi-byte scalar type (e.g., i8, i16, i32). static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter, - VectorType subByteVecType, - VectorType dstType, + VectorType subByteVecTy, + Type containerTy, Operation *op) { - if (!subByteVecType || !dstType) - return rewriter.notifyMatchFailure(op, "Not a supported aligned case"); - unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth(); - unsigned dstElemBitwidth = dstType.getElementTypeBitWidth(); + assert(containerTy.isIntOrFloat() && + "container element type is not a scalar"); - if (dstElemBitwidth < 8) - return rewriter.notifyMatchFailure( - op, "the bitwidth of dstType must be greater than or equal to 8"); - if (dstElemBitwidth % srcElemBitwidth != 0) - return rewriter.notifyMatchFailure(op, "unaligned cases are not supported"); - if (srcElemBitwidth != 2 && srcElemBitwidth != 4) + // TODO: This is validating the inputs rather than checking the conditions + // documented above. Replace with an assert. + if (!subByteVecTy) + return rewriter.notifyMatchFailure(op, "not a vector!"); + + unsigned subByteBits = subByteVecTy.getElementTypeBitWidth(); + unsigned containerBits = containerTy.getIntOrFloatBitWidth(); + + // Enforced by the common pre-conditions. 
+ assert(containerBits % 8 == 0 && "Not a multi-byte scalar type!"); + + // TODO: Add support for other widths (when/if needed) + if (subByteBits != 2 && subByteBits != 4) return rewriter.notifyMatchFailure( - op, "only src bitwidth of 2 or 4 is supported at this moment"); + op, "only 2-bit and 4-bit sub-byte type is supported at this moment"); + + // Condition 1 ("per-element" alignment) + if (containerBits % subByteBits != 0) + return rewriter.notifyMatchFailure(op, "unaligned element types"); - const int numSrcElemsPerByte = 8 / srcElemBitwidth; - if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0) + // Condition 2 ("full" alignment) + if (!fitsInMultiByteContainerTy(subByteVecTy, containerTy)) return rewriter.notifyMatchFailure( - op, "the trailing dimension of the input vector of sub-bytes must be a " - "multiple of 8 / "); + op, "not possible to fit this sub-byte vector type into a vector of " + "the given multi-byte type"); return success(); } @@ -1899,8 +1960,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern { return failure(); // Check general alignment preconditions. - if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType, - conversionOp))) + if (failed(alignedConversionPrecondition( + rewriter, srcVecType, + /*containerTy=*/rewriter.getI8Type(), conversionOp))) return failure(); // Perform the rewrite. @@ -1964,8 +2026,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern { // Check general alignment preconditions. We invert the src/dst type order // to reuse the existing precondition logic. - if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType, - truncOp))) + if (failed(alignedConversionPrecondition( + rewriter, dstVecType, + /*containerTy=*/rewriter.getI8Type(), truncOp))) return failure(); // Create a new iX -> i8 truncation op.