diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 95064083b21d4..373b8a8822318 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -45,6 +45,10 @@ using namespace mlir;
 #define DBGSNL() (llvm::dbgs() << "\n")
 #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -282,13 +286,15 @@ emulatedVectorLoad(OpBuilder &rewriter, Location loc, Value base,
                                 OpFoldResult linearizedIndices,
                                 int64_t numEmultedElementsToLoad,
                                 Type origElemType, Type emulatedElemType) {
-  auto scale = emulatedElemType.getIntOrFloatBitWidth() /
-               origElemType.getIntOrFloatBitWidth();
+  auto elementsPerContainerType = emulatedElemType.getIntOrFloatBitWidth() /
+                                  origElemType.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numEmultedElementsToLoad, emulatedElemType), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numEmultedElementsToLoad * scale, origElemType),
+      loc,
+      VectorType::get(numEmultedElementsToLoad * elementsPerContainerType,
+                      origElemType),
       newLoad);
 }
 
@@ -298,6 +304,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document me.
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -314,14 +321,14 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to store when emulating narrow types.
     // Here only the 1-D vector store is considered, and the N-D memref types
@@ -337,7 +344,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
     //   vector<4xi8>
 
     auto origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -346,13 +353,13 @@
     OpFoldResult linearizedIndices;
     std::tie(std::ignore, linearizedIndices) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
             stridedMetadata.getConstifiedMixedOffset(),
             stridedMetadata.getConstifiedMixedSizes(),
             stridedMetadata.getConstifiedMixedStrides(),
             getAsOpFoldResult(adaptor.getIndices()));
 
-    auto numElements = origElements / scale;
+    auto numElements = origElements / elementsPerContainerType;
     auto bitCast = rewriter.create<vector::BitCastOp>(
         loc, VectorType::get(numElements, newElementType),
         op.getValueToStore());
@@ -368,6 +375,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document me.
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -385,17 +393,17 @@ struct ConvertVectorMaskedStore final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getValueToStore().getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
     }
 
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % elementsPerContainerType != 0)
       return failure();
 
     auto stridedMetadata =
@@ -404,7 +412,7 @@
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndicesOfr) =
         memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
@@ -444,12 +452,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
     // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, elementsPerContainerType);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + elementsPerContainerType - 1) /
+                       elementsPerContainerType;
     auto newType = VectorType::get(numElements, newElementType);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
@@ -458,7 +467,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitCastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -477,6 +487,7 @@ struct ConvertVectorMaskedStore final
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document me.
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -493,14 +504,14 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
    Type oldElementType = op.getType().getElementType();
    Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
    }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -532,7 +543,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
 
     auto origElements = op.getVectorType().getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned = origElements % elementsPerContainerType == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -541,21 +553,21 @@
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        elementsPerContainerType);
     Value result =
         emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
                            numElements, oldElementType, newElementType);
@@ -566,7 +578,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       result = dynamicallyExtractSubVector(
          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result),
          resultVector, linearizedInfo.intraDataOffset, origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
          rewriter, loc, op.getType(), result, *foldedIntraVectorOffset,
          origElements);
@@ -580,6 +592,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document me.
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -596,14 +609,14 @@ struct ConvertVectorMaskedLoad final
     auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
    }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -649,7 +662,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isUnalignedEmulation = origElements % scale != 0;
+    bool isUnalignedEmulation = origElements % elementsPerContainerType != 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -657,7 +670,7 @@ struct ConvertVectorMaskedLoad final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
@@ -668,18 +681,21 @@ struct ConvertVectorMaskedLoad final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            elementsPerContainerType, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
 
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
     auto loadType = VectorType::get(numElements, newElementType);
-    auto newBitcastType = VectorType::get(numElements * scale, oldElementType);
+    auto newBitcastType =
+        VectorType::get(numElements * elementsPerContainerType, oldElementType);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -706,8 +722,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * elementsPerContainerType, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -737,10 +753,43 @@ struct ConvertVectorMaskedLoad final
   }
 };
 
+/// Check whether `subByteVecTy` fits within a vector of `multiByteScalarTy`.
+///
+/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
+/// vector<4xi4>) can fit within N scalar elements of type `multiByteScalarTy`
+/// (a multi-byte scalar, e.g. i16), where N is some integer.
+///
+/// Put differently, this method checks whether this would be valid:
+///
+///   vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
+///
+/// EXAMPLES:
+///   * vector<4xi4> -> i16 - yes (N = 1)
+///   * vector<4xi4> -> i8 - yes (N = 2)
+///   * vector<3xi4> -> i8 - no (N would have to be 1.5)
+///   * vector<3xi2> -> i16 - no (N would have to be 0.375)
+static bool isSubByteVecFittable(VectorType subByteVecTy,
+                                 Type multiByteScalarTy) {
+  assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
+
+  int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
+  int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
+
+  assert(subByteBits < 8 && "Not a sub-byte scalar type!");
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+  assert(multiByteBits % subByteBits == 0 && "Unaligned element types!");
+
+  int elemsPerMultiByte = multiByteBits / subByteBits;
+
+  // TODO: This is a bit too restrictive for vectors of rank > 1.
+  return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
+}
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document me.
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -758,18 +807,20 @@ struct ConvertVectorTransferRead final
     auto convertedType = cast<MemRefType>(adaptor.getSource().getType());
     Type oldElementType = op.getType().getElementType();
     Type newElementType = convertedType.getElementType();
-    int srcBits = oldElementType.getIntOrFloatBitWidth();
-    int dstBits = newElementType.getIntOrFloatBitWidth();
+    int oldBits = oldElementType.getIntOrFloatBitWidth();
+    int newBits = newElementType.getIntOrFloatBitWidth();
 
-    if (dstBits % srcBits != 0) {
-      return rewriter.notifyMatchFailure(
-          op, "only dstBits % srcBits == 0 supported");
+    // Check per-element alignment.
+    if (newBits % oldBits != 0) {
+      return rewriter.notifyMatchFailure(op, "unaligned element types");
    }
-    int scale = dstBits / srcBits;
+    int elementsPerContainerType = newBits / oldBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isUnalignedEmulation = origElements % scale != 0;
+    // Note: per-element alignment was already verified above.
+    bool isFullyAligned =
+        isSubByteVecFittable(op.getVectorType(), newElementType);
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, newElementType,
                                                       adaptor.getPadding());
@@ -781,20 +832,20 @@ struct ConvertVectorTransferRead final
     memref::LinearizedMemRefInfo linearizedInfo;
     std::tie(linearizedInfo, linearizedIndices) =
        memref::getLinearizedMemRefOffsetAndSize(
-            rewriter, loc, srcBits, dstBits,
+            rewriter, loc, oldBits, newBits,
            stridedMetadata.getConstifiedMixedOffset(),
            stridedMetadata.getConstifiedMixedSizes(),
            stridedMetadata.getConstifiedMixedStrides(),
            getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isUnalignedEmulation
-            ? getConstantIntValue(linearizedInfo.intraDataOffset)
-            : 0;
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(elementsPerContainerType - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        elementsPerContainerType);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -802,7 +853,9 @@ struct ConvertVectorTransferRead final
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newRead);
+        loc,
+        VectorType::get(numElements * elementsPerContainerType, oldElementType),
+        newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {
@@ -811,7 +864,7 @@ struct ConvertVectorTransferRead final
       result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                            linearizedInfo.intraDataOffset,
                                            origElements);
-    } else if (isUnalignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
          rewriter, loc, op.getType(), result, *foldedIntraVectorOffset,
          origElements);
@@ -1069,41 +1122,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
   return commonConversionPrecondition(rewriter, preconditionType, op);
 }
 
-/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
-/// means that:
-///   1. The `dstType` element type is a multiple of the
-///   `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
-///   is not supported). Let this multiple be `N`.
-///   2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
-///   multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
-///   not supported).
+/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
+///
+/// Alignment means that `subByteVecTy` can be packed into a vector of
+/// `containerTy` elements. More specifically:
+///   1. The bit-width of `containerTy` is a multiple of the
+///      bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
+///      this multiple is 4.
+///   2. The multiple from 1. above divides evenly the number of the (trailing)
+///      elements in `subByteVecTy`.
+///
+/// EXAMPLE 1:
+///   `subByteVecTy = vector<4xi4>`, and
+///   `containerTy = i16`
+///
+///   4 ( = 16 / 4) divides evenly 4, hence both conditions are _met_.
+///
+/// EXAMPLE 2:
+///   `subByteVecTy = vector<3xi4>`, and
+///   `containerTy = i16`
+///
+///   4 ( = 16 / 4) _does not_ divide 3 evenly, hence the conditions are _not met_.
+///
+/// EXAMPLE 3:
+///   `subByteVecTy = vector<3xi3>`, and
+///   `containerTy = i16`
+///
+///   16 _is not_ a multiple of 3, hence the conditions are _not met_.
 ///
 /// NOTE: This method assumes that common conversion preconditions are met. In
-/// particular, the element type of `dstType` is assumed to be a multi-byte
-/// type (e.g. i8, i16, i32).
+/// particular, `containerTy` is assumed to be a multi-byte scalar type
+/// (e.g., i8, i16, i32).
 static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
-                                                   VectorType subByteVecType,
-                                                   VectorType dstType,
+                                                   VectorType subByteVecTy,
+                                                   Type containerTy,
                                                    Operation *op) {
-  if (!subByteVecType || !dstType)
-    return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
-  unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
-  unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!subByteVecTy)
+    return rewriter.notifyMatchFailure(op, "not a vector!");
 
-  if (dstElemBitwidth < 8)
-    return rewriter.notifyMatchFailure(
-        op, "the bitwidth of dstType must be greater than or equal to 8");
-  if (dstElemBitwidth % srcElemBitwidth != 0)
-    return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
-  if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
+  // TODO: This is validating the inputs rather than checking the conditions
+  // documented above. Replace with an assert.
+  if (!containerTy.isIntOrFloat())
+    return rewriter.notifyMatchFailure(op, "not a scalar!");
+
+  unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
+  unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
+
+  // Enforced by the common pre-conditions.
+  assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
+
+  // TODO: Remove this condition - the assert above (and
+  // commonConversionPrecondition) takes care of that.
+  if (multiByteBits < 8)
+    return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
+
+  // TODO: Add support for other widths (when/if needed).
+  if (subByteBits != 2 && subByteBits != 4)
     return rewriter.notifyMatchFailure(
-        op, "only src bitwidth of 2 or 4 is supported at this moment");
+        op, "only 2-bit and 4-bit sub-byte types are supported at this moment");
+
+  // Condition 1.
+  if (multiByteBits % subByteBits != 0)
+    return rewriter.notifyMatchFailure(op, "unaligned element types");
 
-  const int numSrcElemsPerByte = 8 / srcElemBitwidth;
-  if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
+  // Condition 2.
+  if (!isSubByteVecFittable(subByteVecTy, containerTy))
     return rewriter.notifyMatchFailure(
-        op, "the trailing dimension of the input vector of sub-bytes must be a "
-            "multiple of 8 / <sub-byte-width>");
+        op, "not possible to fit this sub-byte vector type into a vector of "
+            "the given multi-byte type");
 
   return success();
 }
 
@@ -1495,33 +1583,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 /// LLVM to scramble with peephole optimizations. Templated to choose between
 /// signed and unsigned conversions.
 ///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
-///   is rewriten as
-///     %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///     %1 = arith.shli %0, 4 : vector<4xi8>
-///     %2 = arith.shrsi %1, 4 : vector<4xi8>
-///     %3 = arith.shrsi %0, 4 : vector<4xi8>
-///     %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///     %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
+/// EXAMPLE 2 (fp):
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
-///   is rewriten as
-///     %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///     %1 = arith.shli %0, 4 : vector<4xi8>
-///     %2 = arith.shrsi %1, 4 : vector<4xi8>
-///     %3 = arith.shrsi %0, 4 : vector<4xi8>
-///     %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///     %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.shli %0, 4 : vector<4xi8>
+///    %2 = arith.shrsi %1, 4 : vector<4xi8>
+///    %3 = arith.shrsi %0, 4 : vector<4xi8>
+///    %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///    %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
-///   is rewritten as
-///     %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///     %1 = arith.andi %0, 15 : vector<4xi8>
-///     %2 = arith.shrui %0, 4 : vector<4xi8>
-///     %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-///     %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+/// is rewritten as:
+///    %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///    %1 = arith.andi %0, 15 : vector<4xi8>
+///    %2 = arith.shrui %0, 4 : vector<4xi8>
+///    %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+///    %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
@@ -1531,16 +1620,17 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
 
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
       return failure();
 
     // Check general alignment preconditions.
-    if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
-                                             conversionOp)))
+    Type containerType = rewriter.getI8Type();
+    if (failed(alignedConversionPrecondition(rewriter, srcVecType,
+                                             containerType, conversionOp)))
       return failure();
 
     // Perform the rewrite.
@@ -1572,15 +1662,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 /// For example:
 ///   arith.trunci %in : vector<8xi32> to vector<8xi4>
-///   is rewriten as
 ///
-///   %cst = arith.constant dense<15> : vector<4xi8>
-///   %cst_0 = arith.constant dense<4> : vector<4xi8>
-///   %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-///   %2 = arith.andi %0, %cst : vector<4xi8>
-///   %3 = arith.shli %1, %cst_0 : vector<4xi8>
-///   %4 = arith.ori %2, %3 : vector<4xi8>
-///   %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+/// is rewritten as:
+///
+///   %cst = arith.constant dense<15> : vector<4xi8>
+///   %cst_0 = arith.constant dense<4> : vector<4xi8>
+///   %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+///   %2 = arith.andi %0, %cst : vector<4xi8>
+///   %3 = arith.shli %1, %cst_0 : vector<4xi8>
+///   %4 = arith.ori %2, %3 : vector<4xi8>
+///   %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1603,8 +1694,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
     // Check general alignment preconditions. We invert the src/dst type order
     // to reuse the existing precondition logic.
-    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
-                                             truncOp)))
+    Type containerType = rewriter.getI8Type();
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType,
+                                             containerType, truncOp)))
       return failure();
 
     // Create a new iX -> i8 truncation op.
@@ -1624,10 +1716,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
 ///
-/// is rewritten as:
+///   is rewritten as:
 ///
 ///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1675,6 +1768,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
 
+// The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
@@ -1687,22 +1781,26 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
 
 void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  // TODO: Document what the emulated type is.
   patterns.add<RewriteBitCast, RewriteExtOfBitCast<arith::ExtSIOp>,
                RewriteExtOfBitCast<arith::ExtUIOp>>(patterns.getContext(),
                                                     benefit);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
+  // The emulated type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
+  // The emulated type is always i8.
   patterns
      .add<RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
 }
 
+// The emulated type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);
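
Editor's note (not part of the patch): for reviewers unfamiliar with how these
patterns are driven, below is a minimal sketch of how the narrow-type emulation
patterns touched above are typically wired into a dialect conversion, mirroring
the existing narrow-type emulation test pass. The function name
`emulateSubByteVectors`, the exact include list, and the choice of an 8-bit
container type are illustrative assumptions, not part of this change.

  #include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
  #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
  #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
  #include "mlir/Transforms/DialectConversion.h"

  using namespace mlir;

  // Hypothetical driver: rewrite all sub-byte vector loads/stores nested under
  // `root` so that memory is accessed through i8 "container" elements.
  static LogicalResult emulateSubByteVectors(Operation *root) {
    MLIRContext *ctx = root->getContext();

    // The emulated (container) element type comes from this converter: memref
    // element types narrower than 8 bits are widened to i8.
    arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

    ConversionTarget target(*ctx);
    target.markUnknownOpDynamicallyLegal(
        [&typeConverter](Operation *op) { return typeConverter.isLegal(op); });

    RewritePatternSet patterns(ctx);
    // Patterns updated in this diff (vector.load/store, the masked variants
    // and vector.transfer_read), plus the matching memref patterns.
    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);
    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

    return applyPartialConversion(root, target, std::move(patterns));
  }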