[mlir][Vector] Support efficient shape cast lowering for n-D vectors #123497
Changes from 1 commit
@@ -11,40 +11,41 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/VectorInterfaces.h"

#define DEBUG_TYPE "vector-shape-cast-lowering"

using namespace mlir;
using namespace mlir::vector;

/// Increments n-D `indices` by `step` starting from the innermost dimension.
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
                   int step = 1) {
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    indices[dim] += step;
    if (indices[dim] < vecType.getDimSize(dim))
      break;

    indices[dim] = 0;
    step = 1;
  }
}

namespace {
/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
/// vectors progressively on the way to target llvm.matrix intrinsics.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOp2DDownCastRewritePattern
/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening n-D to 1-D
/// vectors progressively. This iterates over the n-1 major dimensions of the
/// n-D vector and performs rewrites into:
/// vector.extract from n-D + vector.insert_strided_slice offset into 1-D
class ShapeCastOpNDDownCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;
@@ -53,35 +54,52 @@ class ShapeCastOp2DDownCastRewritePattern
                  PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if (srcRank < 2 || resRank != 1)
      return failure();

    // Compute the number of 1-D vector elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
      numElts *= sourceVectorType.getDimSize(dim);

    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
    SmallVector<int64_t> srcIdx(srcRank - 1);
    SmallVector<int64_t> resIdx(resRank);

    int64_t extractSize = sourceVectorType.getShape().back();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
      desc = rewriter.create<vector::InsertStridedSliceOp>(
          loc, vec, desc,
          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);

    // Compute the indices of each 1-D vector element of the source extraction
    // and destination slice insertion and generate such instructions.
    for (int64_t i = 0; i < numElts; ++i) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, /*step=*/1);
        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
      }

      Value extract =
          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
      result = rewriter.create<vector::InsertStridedSliceOp>(
          loc, extract, result,
          /*offsets=*/resIdx, /*strides=*/1);
    }
    rewriter.replaceOp(op, desc);

    rewriter.replaceOp(op, result);
    return success();
  }
};
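To make the new index arithmetic concrete, here is a small standalone sketch (hypothetical shapes and helper names; plain `std::vector` stands in for MLIR's `SmallVector`/`VectorType`; not code from this PR). It reproduces the down-cast stepping for a `vector<2x3x8xf32>` -> `vector<48xf32>` shape cast: `srcIdx` walks the two leading dimensions one element at a time, while the 1-D insertion offset advances by the innermost size.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Same odometer-style increment as the patch's incIdx, over a plain shape.
static void incIdxDemo(std::vector<int64_t> &indices,
                       const std::vector<int64_t> &shape, int64_t step = 1) {
  for (int dim = static_cast<int>(indices.size()) - 1; dim >= 0; --dim) {
    indices[dim] += step;
    if (indices[dim] < shape[dim])
      break;
    indices[dim] = 0;
    step = 1;
  }
}

int main() {
  // vector<2x3x8xf32> -> vector<48xf32>: 2 * 3 = 6 extract/insert pairs,
  // each moving one 8-element slice.
  const std::vector<int64_t> leadingDims = {2, 3}; // n-1 major source dims.
  const std::vector<int64_t> resShape = {48};      // Flattened 1-D result.
  const int64_t extractSize = 8;                   // Innermost source dim.
  std::vector<int64_t> srcIdx = {0, 0}, resIdx = {0};
  for (int64_t i = 0; i < leadingDims[0] * leadingDims[1]; ++i) {
    if (i != 0) {
      incIdxDemo(srcIdx, leadingDims, /*step=*/1);
      incIdxDemo(resIdx, resShape, /*step=*/extractSize);
    }
    // Roughly: vector.extract %src[srcIdx] producing a vector<8xf32>, then
    // vector.insert_strided_slice into the 1-D result at offset resIdx[0].
    std::printf("extract [%lld, %lld] -> insert at offset %lld\n",
                (long long)srcIdx[0], (long long)srcIdx[1],
                (long long)resIdx[0]);
  }
  return 0;
}
```

The printed pairs, `(0,0)/0` through `(1,2)/40`, correspond to the six extract + insert_strided_slice pairs the pattern emits for this shape.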
/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
/// vectors progressively.
/// This iterates over the most major dimension of the 2-D vector and performs
/// rewrites into:
/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
/// vectors progressively. This iterates over the n-1 major dimensions of the
/// n-D vector and performs rewrites into:
/// vector.extract_strided_slice from 1-D + vector.insert into n-D
/// Note that 1-D extract_strided_slice ops are lowered to efficient
/// vector.shuffle ops.
class ShapeCastOp2DUpCastRewritePattern
class ShapeCastOpNDUpCastRewritePattern
    : public OpRewritePattern<vector::ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;
@@ -90,43 +108,43 @@ class ShapeCastOp2DUpCastRewritePattern
                  PatternRewriter &rewriter) const override {
    auto sourceVectorType = op.getSourceVectorType();
    auto resultVectorType = op.getResultVectorType();

    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if (srcRank != 1 || resRank < 2)
      return failure();

    // Compute the number of 1-D vector elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t dim = 0; dim < resRank - 1; ++dim)
      numElts *= resultVectorType.getDimSize(dim);

    // Compute the indices of each 1-D vector element of the source slice
    // extraction and destination insertion and generate such instructions.
    auto loc = op.getLoc();
    Value desc = rewriter.create<arith::ConstantOp>(
    SmallVector<int64_t> srcIdx(srcRank);
    SmallVector<int64_t> resIdx(resRank - 1);

    int64_t extractSize = resultVectorType.getShape().back();
    Value result = rewriter.create<arith::ConstantOp>(
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
          /*sizes=*/mostMinorVectorSize,
    for (int64_t i = 0; i < numElts; ++i) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
        incIdx(resIdx, resultVectorType, /*step=*/1);
      }

      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
          /*strides=*/1);
      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
    }
    rewriter.replaceOp(op, desc);
    rewriter.replaceOp(op, result);
    return success();
  }
};

static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
                   int dimIdx, int initialStep = 1) {
  int step = initialStep;
  for (int d = dimIdx; d >= 0; d--) {
    idx[d] += step;
    if (idx[d] >= tp.getDimSize(d)) {
      idx[d] = 0;
      step = 1;
    } else {
      break;
    }
  }
}

// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
@@ -145,18 +163,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
    if (sourceVectorType.isScalable() || resultVectorType.isScalable())
      return failure();

    // Special case 2D / 1D lowerings with better implementations.
    // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
    // Special case for n-D / 1-D lowerings with better implementations.
    int64_t srcRank = sourceVectorType.getRank();
    int64_t resRank = resultVectorType.getRank();
    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
      return failure();

    // Generic ShapeCast lowering path goes all the way down to unrolled scalar
    // extract/insert chains.
    // TODO: consider evolving the semantics to only allow 1D source or dest and
    // drop this potentially very expensive lowering.
    // Compute number of elements involved in the reshape.
    int64_t numElts = 1;
    for (int64_t r = 0; r < srcRank; r++)
      numElts *= sourceVectorType.getDimSize(r);
@@ -172,8 +186,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
        loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
    for (int64_t i = 0; i < numElts; i++) {
      if (i != 0) {
        incIdx(srcIdx, sourceVectorType, srcRank - 1);
        incIdx(resIdx, resultVectorType, resRank - 1);
        incIdx(srcIdx, sourceVectorType);
        incIdx(resIdx, resultVectorType);
      }

      Value extract;
@@ -252,7 +266,7 @@ class ScalableShapeCastOpRewritePattern
    // have a single trailing scalable dimension. This is because there are no
    // legal representation of other scalable types in LLVM (and likely won't be
    // soon). There are also (currently) no operations that can index or extract
    // from >= 2D scalable vectors or scalable vectors of fixed vectors.
    // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
    if (!isTrailingDimScalable(sourceVectorType) ||
        !isTrailingDimScalable(resultVectorType)) {
      return failure();
@@ -334,8 +348,8 @@ class ScalableShapeCastOpRewritePattern

      // 4. Increment the insert/extract indices, stepping by minExtractionSize
      // for the trailing dimensions.
      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
    }

    rewriter.replaceOp(op, result);
@@ -352,8 +366,8 @@ class ScalableShapeCastOpRewritePattern

void mlir::vector::populateVectorShapeCastLoweringPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns.add<ShapeCastOp2DDownCastRewritePattern,
               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
  patterns.add<ShapeCastOpNDDownCastRewritePattern,
               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
               ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                  benefit);
}
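For readers unfamiliar with how these patterns get used, here is a minimal sketch of driving them from a pass-like entry point. The wrapper function `lowerShapeCasts` is illustrative only; `populateVectorShapeCastLoweringPatterns` is the entry point touched in this diff, and the driver call assumes the classic `applyPatternsAndFoldGreedily` spelling (newer MLIR revisions spell it `applyPatternsGreedily`).

```cpp
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Illustrative helper (not part of this PR): collect the shape_cast lowering
// patterns, including the new n-D down-/up-cast ones, and apply them greedily.
static LogicalResult lowerShapeCasts(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorShapeCastLoweringPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```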
Review comment: Should we add an assertion that `indices[dim] == vecType.getDimSize(dim)`? It looks weird to me when it happens, so an assertion would serve as a sanity check in this case.
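If that suggestion is adopted, one possible shape for the check is to assert, inside the wrap branch of `incIdx`, that the index landed exactly on the dimension size, i.e. the step never overshoots a dimension. A sketch, not code from the PR, reusing the helper as written in this diff and relying on the file's existing includes:

```cpp
/// Increments n-D `indices` by `step` starting from the innermost dimension.
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
                   int step = 1) {
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    indices[dim] += step;
    if (indices[dim] < vecType.getDimSize(dim))
      break;

    // Sanity check suggested in review: a wrapping step must land exactly on
    // the dimension size; anything larger would mean the step does not divide
    // the dimension evenly.
    assert(indices[dim] == vecType.getDimSize(dim) &&
           "Increment must wrap exactly at the dimension size");
    indices[dim] = 0;
    step = 1;
  }
}
```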