From aaeb0fb646105b5af5b9d1841a49120c39f03b9d Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Sun, 2 Feb 2025 15:36:33 +0000
Subject: [PATCH 1/2] [mlir][Vector] Update VectorEmulateNarrowType.cpp (2/N)

This is PR 2 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no major
functional changes are made/added.

This PR renames the variable "scale". Note that "scale" could mean either:
* "original-elements-per-emulated-type", or
* "emulated-elements-per-original-type".

While it is clear from the context that it is always the former (the original
type is always a sub-byte type and the emulated type is usually `i8`), this PR
reduces the cognitive load by making that explicit.
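To make the renamed quantity concrete, here is a minimal standalone sketch. It is
not code from this patch; the i2-in-i8 numbers are only an illustrative assumption.

```cpp
#include <cassert>
#include <cstdint>

// The ratio this patch renames from "scale" to "emulatedPerContainerElem":
// the number of emulated (sub-byte) elements that fit in one container element.
int64_t emulatedPerContainerElem(int64_t containerBits, int64_t emulatedBits) {
  assert(containerBits % emulatedBits == 0 && "per-element alignment required");
  return containerBits / emulatedBits;
}

int main() {
  // E.g. emulating i2 elements inside i8 containers: 8 / 2 == 4.
  assert(emulatedPerContainerElem(8, 2) == 4);
  return 0;
}
```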
**DEPENDS ON:**
* #123526

Please only review the [top commit](https://github.com/llvm/llvm-project/pull/123527/commits/d40b31bb098e874be488182050c68b887e8d091a).

**GitHub issue to track this work**:
https://github.com/llvm/llvm-project/issues/123630
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 78 +++++++++++--------
 1 file changed, 45 insertions(+), 33 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 0d310dc8be2fe..831c1ab736105 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -290,13 +290,15 @@ static VectorValue emulatedVectorLoad(OpBuilder &rewriter, Location loc,
                                       int64_t numContainerElemsToLoad,
                                       Type emulatedElemTy,
                                       Type containerElemTy) {
-  auto scale = containerElemTy.getIntOrFloatBitWidth() /
-               emulatedElemTy.getIntOrFloatBitWidth();
+  auto emulatedPerContainerElem = containerElemTy.getIntOrFloatBitWidth() /
+                                  emulatedElemTy.getIntOrFloatBitWidth();
   auto newLoad = rewriter.create<vector::LoadOp>(
       loc, VectorType::get(numContainerElemsToLoad, containerElemTy), base,
       getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
   return rewriter.create<vector::BitCastOp>(
-      loc, VectorType::get(numContainerElemsToLoad * scale, emulatedElemTy),
+      loc,
+      VectorType::get(numContainerElemsToLoad * emulatedPerContainerElem,
+                      emulatedElemTy),
       newLoad);
 }
 
@@ -388,10 +390,11 @@ static Value extractSliceIntoByte(ConversionPatternRewriter &rewriter,
          "sliceNumElements * vector element size must be less than or equal to 8");
   assert(8 % vectorElementType.getIntOrFloatBitWidth() == 0 &&
          "vector element must be a valid sub-byte type");
-  auto scale = 8 / vectorElementType.getIntOrFloatBitWidth();
+  auto emulatedPerContainerElem = 8 / vectorElementType.getIntOrFloatBitWidth();
   auto emptyByteVector = rewriter.create<arith::ConstantOp>(
-      loc, VectorType::get({scale}, vectorElementType),
-      rewriter.getZeroAttr(VectorType::get({scale}, vectorElementType)));
+      loc, VectorType::get({emulatedPerContainerElem}, vectorElementType),
+      rewriter.getZeroAttr(
+          VectorType::get({emulatedPerContainerElem}, vectorElementType)));
   auto extracted = staticallyExtractSubvector(rewriter, loc, vector,
                                               extractOffset, sliceNumElements);
   return staticallyInsertSubvector(rewriter, loc, extracted, emptyByteVector,
@@ -656,9 +659,9 @@ struct ConvertVectorMaskedStore final
           "(bit-wise misalignment)");
     }
 
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
     int origElements = op.getValueToStore().getType().getNumElements();
-    if (origElements % scale != 0)
+    if (origElements % emulatedPerContainerElem != 0)
       return failure();
 
     auto stridedMetadata =
@@ -707,12 +710,13 @@ struct ConvertVectorMaskedStore final
     //
     // FIXME: Make an example based on the comment above work (see #115460 for
    // reproducer).
-    FailureOr<Operation *> newMask =
-        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements, scale);
+    FailureOr<Operation *> newMask = getCompressedMaskOp(
+        rewriter, loc, op.getMask(), origElements, emulatedPerContainerElem);
     if (failed(newMask))
       return failure();
 
-    auto numElements = (origElements + scale - 1) / scale;
+    auto numElements = (origElements + emulatedPerContainerElem - 1) /
+                       emulatedPerContainerElem;
     auto newType = VectorType::get(numElements, containerElemTy);
     auto passThru = rewriter.create<arith::ConstantOp>(
         loc, newType, rewriter.getZeroAttr(newType));
@@ -721,7 +725,8 @@ struct ConvertVectorMaskedStore final
         loc, newType, adaptor.getBase(), linearizedIndices,
         newMask.value()->getResult(0), passThru);
 
-    auto newBitCastType = VectorType::get(numElements * scale, emulatedElemTy);
+    auto newBitCastType =
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
     Value valueToStore =
         rewriter.create<vector::BitCastOp>(loc, newBitCastType, newLoad);
     valueToStore = rewriter.create<arith::SelectOp>(
@@ -765,7 +770,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
           op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -797,7 +802,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -818,9 +823,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
-    int64_t maxintraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxintraDataOffset + origElements, scale);
+    int64_t maxintraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
+                                        emulatedPerContainerElem);
     Value result = emulatedVectorLoad(rewriter, loc, adaptor.getBase(),
                                       linearizedIndices, numElements,
                                       emulatedElemTy, containerElemTy);
@@ -870,7 +876,7 @@ struct ConvertVectorMaskedLoad final
           op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     // Adjust the number of elements to load when emulating narrow types,
     // and then cast back to the original type with vector.bitcast op.
@@ -916,7 +922,7 @@ struct ConvertVectorMaskedLoad final
     // subvector at the proper offset after bit-casting.
     auto origType = op.getVectorType();
     auto origElements = origType.getNumElements();
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -935,18 +941,21 @@ struct ConvertVectorMaskedLoad final
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    FailureOr<Operation *> newMask = getCompressedMaskOp(
-        rewriter, loc, op.getMask(), origElements, scale, maxIntraDataOffset);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    FailureOr<Operation *> newMask =
+        getCompressedMaskOp(rewriter, loc, op.getMask(), origElements,
+                            emulatedPerContainerElem, maxIntraDataOffset);
     if (failed(newMask))
       return failure();
 
     Value passthru = op.getPassThru();
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        emulatedPerContainerElem);
     auto loadType = VectorType::get(numElements, containerElemTy);
-    auto newBitcastType = VectorType::get(numElements * scale, emulatedElemTy);
+    auto newBitcastType =
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy);
 
     auto emptyVector = rewriter.create<arith::ConstantOp>(
         loc, newBitcastType, rewriter.getZeroAttr(newBitcastType));
@@ -973,8 +982,8 @@ struct ConvertVectorMaskedLoad final
         rewriter.create<vector::BitCastOp>(loc, newBitcastType, newLoad);
 
     Value mask = op.getMask();
-    auto newSelectMaskType =
-        VectorType::get(numElements * scale, rewriter.getI1Type());
+    auto newSelectMaskType = VectorType::get(
+        numElements * emulatedPerContainerElem, rewriter.getI1Type());
     // TODO: try to fold if op's mask is constant
     auto emptyMask = rewriter.create<arith::ConstantOp>(
         loc, newSelectMaskType, rewriter.getZeroAttr(newSelectMaskType));
@@ -1033,11 +1042,11 @@ struct ConvertVectorTransferRead final
           op, "impossible to pack emulated elements into container elements "
              "(bit-wise misalignment)");
     }
-    int scale = containerBits / emulatedBits;
+    int emulatedPerContainerElem = containerBits / emulatedBits;
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % scale == 0;
+    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1060,9 +1069,10 @@ struct ConvertVectorTransferRead final
             ? 0
             : getConstantIntValue(linearizedInfo.intraDataOffset);
 
-    int64_t maxIntraDataOffset = foldedIntraVectorOffset.value_or(scale - 1);
-    auto numElements =
-        llvm::divideCeil(maxIntraDataOffset + origElements, scale);
+    int64_t maxIntraDataOffset =
+        foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
+    auto numElements = llvm::divideCeil(maxIntraDataOffset + origElements,
+                                        emulatedPerContainerElem);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, containerElemTy), adaptor.getSource(),
@@ -1070,7 +1080,9 @@ struct ConvertVectorTransferRead final
         newPadding);
 
     auto bitCast = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, emulatedElemTy), newRead);
+        loc,
+        VectorType::get(numElements * emulatedPerContainerElem, emulatedElemTy),
+        newRead);
 
     Value result = bitCast->getResult(0);
     if (!foldedIntraVectorOffset) {

From 95f8ad113145083846177a599f3d1e4b6fcaeab1 Mon Sep 17 00:00:00 2001
From: Andrzej Warzynski
Date: Fri, 17 Jan 2025 13:54:34 +0000
Subject: [PATCH 2/2] [mlir][Vector] Update VectorEmulateNarrowType.cpp (3/N)

This is PR 3 in a series of N patches aimed at improving
"VectorEmulateNarrowType.cpp". This is mainly minor refactoring; no major
functional changes are made/added.

1. Replaces `isUnalignedEmulation` with `isFullyAligned`

   Note that `isUnalignedEmulation` is always computed following a
   "per-element-alignment" condition:

   ```cpp
   // Check per-element alignment.
   if (containerBits % emulatedBits != 0) {
     return rewriter.notifyMatchFailure(
         op, "impossible to pack emulated elements into container elements "
             "(bit-wise misalignment)");
   }
   // (...)
   bool isUnalignedEmulation = origElements % emulatedPerContainerElem != 0;
   ```

   Given that `isUnalignedEmulation` captures only one of the two conditions
   required for "full alignment", keeping the negative form would require a
   name like `isPartiallyUnalignedEmulation`. Instead, I've flipped the
   condition and renamed it `isFullyAligned` (see also the small standalone
   sketch after this list):

   ```cpp
   bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
   ```

2. In addition:
   * Unifies various comments throughout the file (for consistency).
   * Adds new comments throughout the file and adds TODOs where high-level
     comments are missing.
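For illustration only, a minimal standalone sketch of the flipped predicate and
how the `!isFullyAligned` paths are exercised. This is not code from the patch;
the element counts are assumptions chosen for the example.

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the renamed flag: "fully aligned" means the original element count
// is a whole multiple of the emulated-elements-per-container-element ratio.
bool isFullyAligned(int64_t origElements, int64_t emulatedPerContainerElem) {
  return origElements % emulatedPerContainerElem == 0;
}

int main() {
  // 8 x i4 emulated via i8 containers (2 emulated elements per container).
  assert(isFullyAligned(8, 2));
  // 3 x i4 via i8 containers: the "!isFullyAligned" code paths kick in.
  assert(!isFullyAligned(3, 2));
  return 0;
}
```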
---
 .../Transforms/VectorEmulateNarrowType.cpp    | 111 ++++++++++--------
 1 file changed, 64 insertions(+), 47 deletions(-)

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 831c1ab736105..28ccbfbb6962e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -48,6 +48,10 @@ using namespace mlir;
 using VectorValue = TypedValue<VectorType>;
 using MemRefValue = TypedValue<MemRefType>;
 
+//===----------------------------------------------------------------------===//
+// Utils
+//===----------------------------------------------------------------------===//
+
 /// Returns a compressed mask for the emulated vector. For example, when
 /// emulating an eight-element `i8` vector with `i32` (i.e. when the source
 /// elements span two dest elements), this method compresses `vector<8xi1>`
@@ -407,6 +411,7 @@ namespace {
 // ConvertVectorStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -632,6 +637,7 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
 // ConvertVectorMaskedStore
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedStore final
     : OpConversionPattern<vector::MaskedStoreOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -745,6 +751,7 @@
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -802,7 +809,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     // compile time as they must be constants.
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto stridedMetadata =
         rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
@@ -818,9 +826,8 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
       getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     // Always load enough elements which can cover the original elements.
     int64_t maxintraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
     auto numElements = llvm::divideCeil(maxintraDataOffset + origElements,
                                         emulatedPerContainerElem);
     Value result = emulatedVectorLoad(rewriter, loc, adaptor.getBase(),
                                       linearizedIndices, numElements,
                                       emulatedElemTy, containerElemTy);
@@ -834,10 +841,10 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
     if (!foldedIntraVectorOffset) {
       auto resultVector = rewriter.create<arith::ConstantOp>(
           loc, op.getType(), rewriter.getZeroAttr(op.getType()));
-      result = dynamicallyExtractSubVector(rewriter, loc, result, resultVector,
-                                           linearizedInfo.intraDataOffset,
-                                           origElements);
-    } else if (!isAlignedEmulation) {
+      result = dynamicallyExtractSubVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          linearizedInfo.intraDataOffset, origElements);
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -850,6 +857,7 @@
 // ConvertVectorMaskedLoad
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorMaskedLoad final
     : OpConversionPattern<vector::MaskedLoadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1016,6 +1024,7 @@ struct ConvertVectorMaskedLoad final
 // ConvertVectorTransferRead
 //===----------------------------------------------------------------------===//
 
+// TODO: Document-me
 struct ConvertVectorTransferRead final
     : OpConversionPattern<vector::TransferReadOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1046,7 +1055,8 @@ struct ConvertVectorTransferRead final
 
     auto origElements = op.getVectorType().getNumElements();
 
-    bool isAlignedEmulation = origElements % emulatedPerContainerElem == 0;
+    // Note, per-element-alignment was already verified above.
+    bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
 
     auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
                                                       adaptor.getPadding());
@@ -1065,9 +1075,8 @@ struct ConvertVectorTransferRead final
       getAsOpFoldResult(adaptor.getIndices()));
 
     std::optional<int64_t> foldedIntraVectorOffset =
-        isAlignedEmulation
-            ? 0
-            : getConstantIntValue(linearizedInfo.intraDataOffset);
+        isFullyAligned ? 0
+                       : getConstantIntValue(linearizedInfo.intraDataOffset);
 
     int64_t maxIntraDataOffset =
         foldedIntraVectorOffset.value_or(emulatedPerContainerElem - 1);
@@ -1091,7 +1100,7 @@ struct ConvertVectorTransferRead final
      result = dynamicallyExtractSubVector(rewriter, loc, bitCast, zeros,
                                           linearizedInfo.intraDataOffset,
                                           origElements);
-    } else if (!isAlignedEmulation) {
+    } else if (!isFullyAligned) {
       result = staticallyExtractSubvector(
           rewriter, loc, result, *foldedIntraVectorOffset, origElements);
     }
@@ -1774,33 +1783,34 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
 /// LLVM to scramble with peephole optimizations. Templated to choose between
 /// signed and unsigned conversions.
 ///
-/// For example (signed):
+/// EXAMPLE 1 (signed):
 ///    arith.extsi %in : vector<8xi4> to vector<8xi32>
-///   is rewriten as
-///     %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///     %1 = arith.shli %0, 4 : vector<4xi8>
-///     %2 = arith.shrsi %1, 4 : vector<4xi8>
-///     %3 = arith.shrsi %0, 4 : vector<4xi8>
-///     %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///     %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+///   is rewritten as:
+///      %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///      %1 = arith.shli %0, 4 : vector<4xi8>
+///      %2 = arith.shrsi %1, 4 : vector<4xi8>
+///      %3 = arith.shrsi %0, 4 : vector<4xi8>
+///      %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///      %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
 ///
+/// EXAMPLE 2 (fp):
 ///    arith.sitofp %in : vector<8xi4> to vector<8xf32>
-///   is rewriten as
-///     %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///     %1 = arith.shli %0, 4 : vector<4xi8>
-///     %2 = arith.shrsi %1, 4 : vector<4xi8>
-///     %3 = arith.shrsi %0, 4 : vector<4xi8>
-///     %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
-///     %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+///   is rewritten as:
+///      %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///      %1 = arith.shli %0, 4 : vector<4xi8>
+///      %2 = arith.shrsi %1, 4 : vector<4xi8>
+///      %3 = arith.shrsi %0, 4 : vector<4xi8>
+///      %4 = vector.interleave %2, %3 : vector<4xi8> -> vector<8xi8>
+///      %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
 ///
-/// Example (unsigned):
+/// EXAMPLE 3 (unsigned):
 ///    arith.extui %in : vector<8xi4> to vector<8xi32>
-///   is rewritten as
-///     %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
-///     %1 = arith.andi %0, 15 : vector<4xi8>
-///     %2 = arith.shrui %0, 4 : vector<4xi8>
-///     %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
-///     %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
+///   is rewritten as:
+///      %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+///      %1 = arith.andi %0, 15 : vector<4xi8>
+///      %2 = arith.shrui %0, 4 : vector<4xi8>
+///      %3 = vector.interleave %1, %2 : vector<4xi8> -> vector<8xi8>
+///      %4 = arith.extui %3 : vector<8xi8> to vector<8xi32>
 ///
 template <typename ConversionOpType, bool isSigned>
 struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
@@ -1810,8 +1820,8 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
                                 PatternRewriter &rewriter) const override {
     // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
-    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
-    auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+    VectorType srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    VectorType dstVecType = dyn_cast<VectorType>(conversionOp.getType());
 
     if (failed(
             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
@@ -1851,15 +1861,16 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
 ///
 /// For example:
 ///    arith.trunci %in : vector<8xi32> to vector<8xi4>
-///   is rewriten as
 ///
-///   %cst = arith.constant dense<15> : vector<4xi8>
-///   %cst_0 = arith.constant dense<4> : vector<4xi8>
-///   %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
-///   %2 = arith.andi %0, %cst : vector<4xi8>
-///   %3 = arith.shli %1, %cst_0 : vector<4xi8>
-///   %4 = arith.ori %2, %3 : vector<4xi8>
-///   %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
+///   is rewritten as:
+///
+///     %cst = arith.constant dense<15> : vector<4xi8>
+///     %cst_0 = arith.constant dense<4> : vector<4xi8>
+///     %0, %1 = vector.deinterleave %in : vector<8xi8>, vector<8xi8>
+///     %2 = arith.andi %0, %cst : vector<4xi8>
+///     %3 = arith.shli %1, %cst_0 : vector<4xi8>
+///     %4 = arith.ori %2, %3 : vector<4xi8>
+///     %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
 ///
 struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
   using OpRewritePattern::OpRewritePattern;
@@ -1903,10 +1914,11 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
 
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
-/// For example:
+///
+/// EXAMPLE:
 ///   %0 = vector.transpose %a, [1, 0] : vector<8x16xi4> to vector<16x8xi4>
 ///
-///   is rewritten as:
+/// is rewritten as:
 ///
 ///   %0 = arith.extsi %arg0 : vector<8x16xi4> to vector<8x16xi8>
 ///   %1 = vector.transpose %0, [1, 0] : vector<8x16xi8> to vector<16x8xi8>
@@ -1954,6 +1966,7 @@ struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
 // Public Interface Definition
 //===----------------------------------------------------------------------===//
 
+// The emulated type is inferred from the converted memref type.
 void vector::populateVectorNarrowTypeEmulationPatterns(
     const arith::NarrowTypeEmulationConverter &typeConverter,
     RewritePatternSet &patterns) {
@@ -1966,22 +1979,26 @@ void vector::populateVectorNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
+  // TODO: Document what the emulated type is.
   patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtSIOp>,
                RewriteExtOfBitCast<arith::ExtUIOp>>(patterns.getContext(),
                                                     benefit);
 
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
+  // The emulated type is always i8.
   patterns.add<RewriteAlignedSubByteIntExt<arith::ExtSIOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntExt<arith::SIToFPOp, /*isSigned=*/true>,
               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
                                              benefit.getBenefit() + 1);
+  // The emulated type is always i8.
   patterns
       .add<RewriteAlignedSubByteIntExt<arith::ExtUIOp, /*isSigned=*/false>,
           RewriteAlignedSubByteIntExt<arith::UIToFPOp, /*isSigned=*/false>>(
          patterns.getContext(), benefit.getBenefit() + 1);
 }
 
+// The emulated type is always i8.
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
   patterns.add<RewriteVectorTranspose>(patterns.getContext(), benefit);