[mlir][Vector] Support efficient shape cast lowering for n-D vectors #123497
This PR generalizes the existing efficient lowering of shape casts between 2-D and 1-D vectors (2-D to 1-D and 1-D to 2-D) to n-D vectors. This significantly reduces code size and generates more performant code for n-D shape casts that make their way to LLVM/SPIR-V.
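As a rough illustration of the code-size claim, the sketch below counts extract/insert pairs for a fixed-shape cast. Both helpers are hypothetical and not part of the patch: `scalarPairs` models a fully unrolled scalar lowering (one pair per element) and `slicePairs` models moving one innermost 1-D slice per pair, which is what the updated `shape_cast_3d1d` test expects.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: pairs emitted when every element is moved on its own.
int64_t scalarPairs(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (int64_t d : shape)
    n *= d;
  return n;
}

// Hypothetical helper: pairs emitted when whole innermost 1-D slices are
// moved at once, i.e. the product of the n-1 major dimensions.
int64_t slicePairs(const std::vector<int64_t> &shape) {
  int64_t n = 1;
  for (std::size_t d = 0; d + 1 < shape.size(); ++d)
    n *= shape[d];
  return n;
}
```

For `vector<1x3x2xf32>`, `scalarPairs({1, 3, 2})` is 6 and `slicePairs({1, 3, 2})` is 3, matching the six scalar pairs removed from the test and the three slice pairs added.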
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir
Author: Diego Caballero (dcaballe)
Changes
This PR implements a generalization of the existing more efficient lowering of shape casts from 2-D to 1-D and 1-D to 2-D vectors. This significantly reduces code size and generates more performant code for n-D shape casts that make their way to LLVM/SPIR-V.
Full diff: https://github.com/llvm/llvm-project/pull/123497.diff
2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..edbf798e1c673b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -11,40 +11,41 @@
//
//===----------------------------------------------------------------------===//
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
#define DEBUG_TYPE "vector-shape-cast-lowering"
using namespace mlir;
using namespace mlir::vector;
+/// Increments n-D `indices` by `step` starting from the innermost dimension.
+static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+ int step = 1) {
+ for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
+ indices[dim] += step;
+ if (indices[dim] < vecType.getDimSize(dim))
+ break;
+
+ indices[dim] = 0;
+ step = 1;
+ }
+}
+
namespace {
-/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
-/// vectors progressively on the way to target llvm.matrix intrinsics.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOp2DDownCastRewritePattern
+/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening N-D to 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the
+/// n-D vector and performs rewrites into:
+/// vector.extract from n-D + vector.insert_strided_slice offset into 1-D
+class ShapeCastOpNDDownCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -53,35 +54,52 @@ class ShapeCastOp2DDownCastRewritePattern
PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
-
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
return failure();
- if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+ int64_t srcRank = sourceVectorType.getRank();
+ int64_t resRank = resultVectorType.getRank();
+ if (srcRank < 2 || resRank != 1)
return failure();
+ // Compute the number of 1-D vector elements involved in the reshape.
+ int64_t numElts = 1;
+ for (int64_t dim = 0; dim < srcRank - 1; ++dim)
+ numElts *= sourceVectorType.getDimSize(dim);
+
auto loc = op.getLoc();
- Value desc = rewriter.create<arith::ConstantOp>(
+ SmallVector<int64_t> srcIdx(srcRank - 1);
+ SmallVector<int64_t> resIdx(resRank);
+ int64_t extractSize = sourceVectorType.getShape().back();
+ Value result = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
- unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
- for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
- Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
- desc = rewriter.create<vector::InsertStridedSliceOp>(
- loc, vec, desc,
- /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+
+ // Compute the indices of each 1-D vector element of the source extraction
+ // and destination slice insertion and generate such instructions.
+ for (int64_t i = 0; i < numElts; ++i) {
+ if (i != 0) {
+ incIdx(srcIdx, sourceVectorType, /*step=*/1);
+ incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+ }
+
+ Value extract =
+ rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+ result = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, extract, result,
+ /*offsets=*/resIdx, /*strides=*/1);
}
- rewriter.replaceOp(op, desc);
+
+ rewriter.replaceOp(op, result);
return success();
}
};
-/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
-/// vectors progressively.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-/// vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
+/// vectors progressively. This iterates over the n-1 major dimension of the n-D
+/// vector and performs rewrites into:
+/// vector.extract_strided_slice from 1-D + vector.insert into n-D
/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOp2DUpCastRewritePattern
+class ShapeCastOpNDUpCastRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -90,43 +108,43 @@ class ShapeCastOp2DUpCastRewritePattern
PatternRewriter &rewriter) const override {
auto sourceVectorType = op.getSourceVectorType();
auto resultVectorType = op.getResultVectorType();
-
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
return failure();
- if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+ int64_t srcRank = sourceVectorType.getRank();
+ int64_t resRank = resultVectorType.getRank();
+ if (srcRank != 1 || resRank < 2)
return failure();
+ // Compute the number of 1-D vector elements involved in the reshape.
+ int64_t numElts = 1;
+ for (int64_t dim = 0; dim < resRank - 1; ++dim)
+ numElts *= resultVectorType.getDimSize(dim);
+
+ // Compute the indices of each 1-D vector element of the source slice
+ // extraction and destination insertion and generate such instructions.
auto loc = op.getLoc();
- Value desc = rewriter.create<arith::ConstantOp>(
+ SmallVector<int64_t> srcIdx(srcRank);
+ SmallVector<int64_t> resIdx(resRank - 1);
+ int64_t extractSize = resultVectorType.getShape().back();
+ Value result = rewriter.create<arith::ConstantOp>(
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
- unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
- for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
- Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
- loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
- /*sizes=*/mostMinorVectorSize,
+ for (int64_t i = 0; i < numElts; ++i) {
+ if (i != 0) {
+ incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
+ incIdx(resIdx, resultVectorType, /*step=*/1);
+ }
+
+ Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
/*strides=*/1);
- desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+ result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
}
- rewriter.replaceOp(op, desc);
+ rewriter.replaceOp(op, result);
return success();
}
};
-static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
- int dimIdx, int initialStep = 1) {
- int step = initialStep;
- for (int d = dimIdx; d >= 0; d--) {
- idx[d] += step;
- if (idx[d] >= tp.getDimSize(d)) {
- idx[d] = 0;
- step = 1;
- } else {
- break;
- }
- }
-}
-
// We typically should not lower general shape cast operations into data
// movement instructions, since the assumption is that these casts are
// optimized away during progressive lowering. For completeness, however,
@@ -145,18 +163,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
if (sourceVectorType.isScalable() || resultVectorType.isScalable())
return failure();
- // Special case 2D / 1D lowerings with better implementations.
- // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
+ // Special case for n-D / 1-D lowerings with better implementations.
int64_t srcRank = sourceVectorType.getRank();
int64_t resRank = resultVectorType.getRank();
- if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+ if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
return failure();
// Generic ShapeCast lowering path goes all the way down to unrolled scalar
// extract/insert chains.
- // TODO: consider evolving the semantics to only allow 1D source or dest and
- // drop this potentially very expensive lowering.
- // Compute number of elements involved in the reshape.
int64_t numElts = 1;
for (int64_t r = 0; r < srcRank; r++)
numElts *= sourceVectorType.getDimSize(r);
@@ -172,8 +186,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
for (int64_t i = 0; i < numElts; i++) {
if (i != 0) {
- incIdx(srcIdx, sourceVectorType, srcRank - 1);
- incIdx(resIdx, resultVectorType, resRank - 1);
+ incIdx(srcIdx, sourceVectorType);
+ incIdx(resIdx, resultVectorType);
}
Value extract;
@@ -252,7 +266,7 @@ class ScalableShapeCastOpRewritePattern
// have a single trailing scalable dimension. This is because there are no
// legal representation of other scalable types in LLVM (and likely won't be
// soon). There are also (currently) no operations that can index or extract
- // from >= 2D scalable vectors or scalable vectors of fixed vectors.
+ // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
if (!isTrailingDimScalable(sourceVectorType) ||
!isTrailingDimScalable(resultVectorType)) {
return failure();
@@ -334,8 +348,8 @@ class ScalableShapeCastOpRewritePattern
// 4. Increment the insert/extract indices, stepping by minExtractionSize
// for the trailing dimensions.
- incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
- incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
+ incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
+ incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
}
rewriter.replaceOp(op, result);
@@ -352,8 +366,8 @@ class ScalableShapeCastOpRewritePattern
void mlir::vector::populateVectorShapeCastLoweringPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
- patterns.add<ShapeCastOp2DDownCastRewritePattern,
- ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
+ patterns.add<ShapeCastOpNDDownCastRewritePattern,
+ ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
ScalableShapeCastOpRewritePattern>(patterns.getContext(),
benefit);
}
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index f2f1211fd70eed..b4c52d5533116c 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
// CHECK-LABEL: func @nop_shape_cast
// CHECK-SAME: %[[A:.*]]: vector<16xf32>
@@ -82,19 +82,16 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
// CHECK-LABEL: func @shape_cast_3d1d
// CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
-// CHECK: return %[[T11]] : vector<6xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
+// CHECK-SAME: {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
+// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
+// CHECK-SAME: {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: return %[[T5]] : vector<6xf32>
func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
%s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
@@ -104,19 +101,13 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
// CHECK-LABEL: func @shape_cast_1d3d
// CHECK-SAME: %[[A:.*]]: vector<6xf32>
// CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<6xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : f32 from vector<6xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : f32 from vector<6xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : f32 from vector<6xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : f32 from vector<6xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : f32 from vector<6xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: return %[[T11]] : vector<2x1x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME: {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK: {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: return %[[T3]] : vector<2x1x3xf32>
func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
%s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>
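The index sequence that the n-D to 1-D pattern walks through can be reproduced outside MLIR. The following is a minimal standalone sketch, assuming `std::vector` in place of `SmallVectorImpl` and a plain shape vector in place of `VectorType`; `downCastPlan` is a hypothetical helper introduced only for illustration.

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for the patch's incIdx: increments `indices` by `step`, starting
// from the innermost dimension and carrying into the outer ones.
void incIdx(std::vector<int64_t> &indices, const std::vector<int64_t> &shape,
            int64_t step = 1) {
  for (int dim = static_cast<int>(indices.size()) - 1; dim >= 0; --dim) {
    indices[dim] += step;
    if (indices[dim] < shape[dim])
      break;
    indices[dim] = 0;
    step = 1;
  }
}

// Hypothetical helper: for a source shape of rank >= 2, lists the
// (source indices, 1-D result offset) of every slice that is moved.
std::vector<std::pair<std::vector<int64_t>, int64_t>>
downCastPlan(const std::vector<int64_t> &srcShape) {
  int64_t numSlices = 1;
  for (std::size_t d = 0; d + 1 < srcShape.size(); ++d)
    numSlices *= srcShape[d];
  int64_t extractSize = srcShape.back();
  std::vector<int64_t> resShape = {numSlices * extractSize};

  std::vector<int64_t> srcIdx(srcShape.size() - 1, 0);
  std::vector<int64_t> resIdx(1, 0);
  std::vector<std::pair<std::vector<int64_t>, int64_t>> plan;
  for (int64_t i = 0; i < numSlices; ++i) {
    if (i != 0) {
      incIdx(srcIdx, srcShape, /*step=*/1);
      incIdx(resIdx, resShape, /*step=*/extractSize);
    }
    plan.push_back({srcIdx, resIdx[0]});
  }
  return plan;
}
```

For the source shape `{1, 3, 2}`, the plan is `[0, 0]` at offset 0, `[0, 1]` at offset 2, and `[0, 2]` at offset 4, which lines up with the `vector.extract` indices and `insert_strided_slice` offsets in the `shape_cast_3d1d` CHECK lines.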
kuhar
left a comment
I cherry-picked this to see if anything breaks in the SPIR-V compilation pipeline in IREE, and it's all good. It doesn't seem to affect the insert_/extract_strided_slice decomposition.
hanhanW
left a comment
Overall looks good to me, thanks for the patch!
SmallVector<int64_t> srcIdx(srcRank - 1);
SmallVector<int64_t> resIdx(resRank);
[optional nit]: Interesting code; it initializes all the values to zero. I know it works and how it works, but it is not common in the codebase, IMO. We usually pass the default value to the constructor. Would you mind explicitly adding the default value to the constructor?
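To make the nit concrete, here is a small standalone sketch using `std::vector` as a stand-in for `SmallVector`; both helper names are hypothetical. The size-only constructor value-initializes the elements, so the two forms produce the same contents, and the difference is only in how explicit the intent is.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: size-only constructor, elements are value-initialized
// (zero for int64_t).
std::vector<int64_t> makeImplicit(std::size_t rank) {
  std::vector<int64_t> idx(rank);
  return idx;
}

// Hypothetical helper: same contents, with the initial value spelled out.
std::vector<int64_t> makeExplicit(std::size_t rank) {
  std::vector<int64_t> idx(rank, 0);
  return idx;
}
```

In the patch, the request would presumably translate to something like `SmallVector<int64_t> srcIdx(srcRank - 1, 0);`, though the thread does not show the final form.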
SmallVector<int64_t> srcIdx(srcRank);
SmallVector<int64_t> resIdx(resRank - 1);
[optional] ditto, if the above suggestion is applied, please also update the code for consistency.
indices[dim] += step;
if (indices[dim] < vecType.getDimSize(dim))
break;
Should we add an assertion that indices[dim] == vecType.getDimSize(dim) at this point? Any other value would look weird to me. The assertion would be a sanity check in this case.
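One way the suggested check could look, sketched standalone with `std::vector` in place of the MLIR types. The name `incIdxChecked` is hypothetical, and whether the equality holds for every caller (for example, the scalable path that steps by `minExtractionSize`) is an assumption this sketch does not establish.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical variant of incIdx with the reviewer's sanity check added.
void incIdxChecked(std::vector<int64_t> &indices,
                   const std::vector<int64_t> &shape, int64_t step = 1) {
  for (int dim = static_cast<int>(indices.size()) - 1; dim >= 0; --dim) {
    indices[dim] += step;
    if (indices[dim] < shape[dim])
      break;
    // Sanity check: when wrapping, the index should land exactly on the
    // dimension size rather than overshoot it.
    assert(indices[dim] == shape[dim] && "index overshot the dimension size");
    indices[dim] = 0;
    step = 1;
  }
}
```

With shape `{2, 3}`, incrementing `{0, 2}` wraps the inner dimension and yields `{1, 0}`. With a 1-D shape `{6}` and a step of 3, `{0}` becomes `{3}` without wrapping.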
hanhanW
left a comment
LGTM, thanks!