Conversation

@dcaballe
Contributor

This PR generalizes the existing, more efficient lowering of shape casts between 2-D and 1-D vectors so that it covers shape casts between n-D and 1-D vectors in either direction. This significantly reduces code size and generates more performant code for n-D shape casts that make their way to LLVM/SPIR-V.

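For illustration only (not part of the PR): the core of the generalized lowering is a carry-based increment over the n-1 major dimensions of an n-D vector shape. Below is a minimal standalone C++ sketch of that stepping logic, using hypothetical names and std::vector in place of the patch's MLIR types:

#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch of the patch's index stepping: bump an n-D index by `step` in the
// innermost dimension and carry into more-major dimensions on overflow.
static void incIdxSketch(std::vector<int64_t> &indices,
                         const std::vector<int64_t> &shape,
                         int64_t step = 1) {
  for (int dim = static_cast<int>(indices.size()) - 1; dim >= 0; --dim) {
    indices[dim] += step;
    if (indices[dim] < shape[dim])
      break;
    indices[dim] = 0;
    step = 1; // Outer dimensions always carry by one.
  }
}

int main() {
  // For vector<2x3x4xf32> -> vector<24xf32>, the lowering visits the 2*3 = 6
  // leading index tuples: [0,0] [0,1] [0,2] [1,0] [1,1] [1,2].
  std::vector<int64_t> shape = {2, 3};
  std::vector<int64_t> idx = {0, 0};
  for (int i = 0; i < 6; ++i) {
    std::printf("[%lld, %lld]\n", (long long)idx[0], (long long)idx[1]);
    incIdxSketch(idx, shape);
  }
  return 0;
}

Each visited tuple becomes one vector.extract of a 1-D sub-vector plus one vector.insert_strided_slice at the matching 1-D offset, as the test changes below show.
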
@llvmbot
Member

llvmbot commented Jan 19, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

This PR generalizes the existing, more efficient lowering of shape casts between 2-D and 1-D vectors so that it covers shape casts between n-D and 1-D vectors in either direction. This significantly reduces code size and generates more performant code for n-D shape casts that make their way to LLVM/SPIR-V.


Full diff: https://github.com/llvm/llvm-project/pull/123497.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp (+84-70)
  • (modified) mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir (+18-27)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 95ebd4e9fe3d99..edbf798e1c673b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -11,40 +11,41 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
-#include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/BuiltinTypes.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/Location.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "mlir/Interfaces/VectorInterfaces.h"
 
 #define DEBUG_TYPE "vector-shape-cast-lowering"
 
 using namespace mlir;
 using namespace mlir::vector;
 
+/// Increments n-D `indices` by `step` starting from the innermost dimension.
+static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
+                   int step = 1) {
+  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
+    indices[dim] += step;
+    if (indices[dim] < vecType.getDimSize(dim))
+      break;
+
+    indices[dim] = 0;
+    step = 1;
+  }
+}
+
 namespace {
-/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
-/// vectors progressively on the way to target llvm.matrix intrinsics.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract from 2-D + vector.insert_strided_slice offset into 1-D
-class ShapeCastOp2DDownCastRewritePattern
+/// ShapeOp n-D -> 1-D downcast serves the purpose of flattening n-D to 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the
+/// n-D vector and performs rewrites into:
+///   vector.extract from n-D + vector.insert_strided_slice offset into 1-D
+class ShapeCastOpNDDownCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -53,35 +54,52 @@ class ShapeCastOp2DDownCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank < 2 || resRank != 1)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < srcRank - 1; ++dim)
+      numElts *= sourceVectorType.getDimSize(dim);
+
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank - 1);
+    SmallVector<int64_t> resIdx(resRank);
+    int64_t extractSize = sourceVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = sourceVectorType.getShape()[1];
-    for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractOp>(loc, op.getSource(), i);
-      desc = rewriter.create<vector::InsertStridedSliceOp>(
-          loc, vec, desc,
-          /*offsets=*/i * mostMinorVectorSize, /*strides=*/1);
+
+    // Compute the indices of each 1-D vector element of the source extraction
+    // and destination slice insertion and generate such instructions.
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/1);
+        incIdx(resIdx, resultVectorType, /*step=*/extractSize);
+      }
+
+      Value extract =
+          rewriter.create<vector::ExtractOp>(loc, op.getSource(), srcIdx);
+      result = rewriter.create<vector::InsertStridedSliceOp>(
+          loc, extract, result,
+          /*offsets=*/resIdx, /*strides=*/1);
     }
-    rewriter.replaceOp(op, desc);
+
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D
-/// vectors progressively.
-/// This iterates over the most major dimension of the 2-D vector and performs
-/// rewrites into:
-///   vector.extract_strided_slice from 1-D + vector.insert into 2-D
+/// ShapeOp 1-D -> n-D upcast serves the purpose of unflattening n-D from 1-D
+/// vectors progressively. This iterates over the n-1 major dimensions of the n-D
+/// vector and performs rewrites into:
+///   vector.extract_strided_slice from 1-D + vector.insert into n-D
 /// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
-class ShapeCastOp2DUpCastRewritePattern
+class ShapeCastOpNDUpCastRewritePattern
     : public OpRewritePattern<vector::ShapeCastOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -90,43 +108,43 @@ class ShapeCastOp2DUpCastRewritePattern
                                 PatternRewriter &rewriter) const override {
     auto sourceVectorType = op.getSourceVectorType();
     auto resultVectorType = op.getResultVectorType();
-
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2)
+    int64_t srcRank = sourceVectorType.getRank();
+    int64_t resRank = resultVectorType.getRank();
+    if (srcRank != 1 || resRank < 2)
       return failure();
 
+    // Compute the number of 1-D vector elements involved in the reshape.
+    int64_t numElts = 1;
+    for (int64_t dim = 0; dim < resRank - 1; ++dim)
+      numElts *= resultVectorType.getDimSize(dim);
+
+    // Compute the indices of each 1-D vector element of the source slice
+    // extraction and destination insertion and generate such instructions.
     auto loc = op.getLoc();
-    Value desc = rewriter.create<arith::ConstantOp>(
+    SmallVector<int64_t> srcIdx(srcRank);
+    SmallVector<int64_t> resIdx(resRank - 1);
+    int64_t extractSize = resultVectorType.getShape().back();
+    Value result = rewriter.create<arith::ConstantOp>(
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
-    unsigned mostMinorVectorSize = resultVectorType.getShape()[1];
-    for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) {
-      Value vec = rewriter.create<vector::ExtractStridedSliceOp>(
-          loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize,
-          /*sizes=*/mostMinorVectorSize,
+    for (int64_t i = 0; i < numElts; ++i) {
+      if (i != 0) {
+        incIdx(srcIdx, sourceVectorType, /*step=*/extractSize);
+        incIdx(resIdx, resultVectorType, /*step=*/1);
+      }
+
+      Value extract = rewriter.create<vector::ExtractStridedSliceOp>(
+          loc, op.getSource(), /*offsets=*/srcIdx, /*sizes=*/extractSize,
           /*strides=*/1);
-      desc = rewriter.create<vector::InsertOp>(loc, vec, desc, i);
+      result = rewriter.create<vector::InsertOp>(loc, extract, result, resIdx);
     }
-    rewriter.replaceOp(op, desc);
+    rewriter.replaceOp(op, result);
     return success();
   }
 };
 
-static void incIdx(llvm::MutableArrayRef<int64_t> idx, VectorType tp,
-                   int dimIdx, int initialStep = 1) {
-  int step = initialStep;
-  for (int d = dimIdx; d >= 0; d--) {
-    idx[d] += step;
-    if (idx[d] >= tp.getDimSize(d)) {
-      idx[d] = 0;
-      step = 1;
-    } else {
-      break;
-    }
-  }
-}
-
 // We typically should not lower general shape cast operations into data
 // movement instructions, since the assumption is that these casts are
 // optimized away during progressive lowering. For completeness, however,
@@ -145,18 +163,14 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
     if (sourceVectorType.isScalable() || resultVectorType.isScalable())
       return failure();
 
-    // Special case 2D / 1D lowerings with better implementations.
-    // TODO: make is ND / 1D to allow generic ND -> 1D -> MD.
+    // Special case for n-D / 1-D lowerings with better implementations.
     int64_t srcRank = sourceVectorType.getRank();
     int64_t resRank = resultVectorType.getRank();
-    if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2))
+    if ((srcRank > 1 && resRank == 1) || (srcRank == 1 && resRank > 1))
       return failure();
 
     // Generic ShapeCast lowering path goes all the way down to unrolled scalar
     // extract/insert chains.
-    // TODO: consider evolving the semantics to only allow 1D source or dest and
-    // drop this potentially very expensive lowering.
-    // Compute number of elements involved in the reshape.
     int64_t numElts = 1;
     for (int64_t r = 0; r < srcRank; r++)
       numElts *= sourceVectorType.getDimSize(r);
@@ -172,8 +186,8 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
         loc, resultVectorType, rewriter.getZeroAttr(resultVectorType));
     for (int64_t i = 0; i < numElts; i++) {
       if (i != 0) {
-        incIdx(srcIdx, sourceVectorType, srcRank - 1);
-        incIdx(resIdx, resultVectorType, resRank - 1);
+        incIdx(srcIdx, sourceVectorType);
+        incIdx(resIdx, resultVectorType);
       }
 
       Value extract;
@@ -252,7 +266,7 @@ class ScalableShapeCastOpRewritePattern
     // have a single trailing scalable dimension. This is because there are no
     // legal representation of other scalable types in LLVM (and likely won't be
     // soon). There are also (currently) no operations that can index or extract
-    // from >= 2D scalable vectors or scalable vectors of fixed vectors.
+    // from >= 2-D scalable vectors or scalable vectors of fixed vectors.
     if (!isTrailingDimScalable(sourceVectorType) ||
         !isTrailingDimScalable(resultVectorType)) {
       return failure();
@@ -334,8 +348,8 @@ class ScalableShapeCastOpRewritePattern
 
       // 4. Increment the insert/extract indices, stepping by minExtractionSize
       // for the trailing dimensions.
-      incIdx(srcIdx, sourceVectorType, srcRank - 1, minExtractionSize);
-      incIdx(resIdx, resultVectorType, resRank - 1, minExtractionSize);
+      incIdx(srcIdx, sourceVectorType, /*step=*/minExtractionSize);
+      incIdx(resIdx, resultVectorType, /*step=*/minExtractionSize);
     }
 
     rewriter.replaceOp(op, result);
@@ -352,8 +366,8 @@ class ScalableShapeCastOpRewritePattern
 
 void mlir::vector::populateVectorShapeCastLoweringPatterns(
     RewritePatternSet &patterns, PatternBenefit benefit) {
-  patterns.add<ShapeCastOp2DDownCastRewritePattern,
-               ShapeCastOp2DUpCastRewritePattern, ShapeCastOpRewritePattern,
+  patterns.add<ShapeCastOpNDDownCastRewritePattern,
+               ShapeCastOpNDUpCastRewritePattern, ShapeCastOpRewritePattern,
                ScalableShapeCastOpRewritePattern>(patterns.getContext(),
                                                   benefit);
 }
diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
index f2f1211fd70eed..b4c52d5533116c 100644
--- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s --transform-interpreter --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
 
 // CHECK-LABEL: func @nop_shape_cast
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>
@@ -82,19 +82,16 @@ func.func @shape_cast_2d2d(%arg0 : vector<3x2xf32>) -> vector<2x3xf32> {
 // CHECK-LABEL: func @shape_cast_3d1d
 // CHECK-SAME: %[[A:.*]]: vector<1x3x2xf32>
 // CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<6xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0] : f32 into vector<6xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 0, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1] : f32 into vector<6xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 1, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [2] : f32 into vector<6xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][0, 1, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [3] : f32 into vector<6xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][0, 2, 0] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [4] : f32 into vector<6xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][0, 2, 1] : f32 from vector<1x3x2xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [5] : f32 into vector<6xf32>
-// CHECK: return %[[T11]] : vector<6xf32>
+// CHECK: %[[T0:.*]] = vector.extract %[[A]][0, 0] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T1:.*]] = vector.insert_strided_slice %[[T0]], %[[C]]
+// CHECK-SAME:           {offsets = [0], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T2:.*]] = vector.extract %[[A]][0, 1] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T3:.*]] = vector.insert_strided_slice %[[T2]], %[[T1]]
+// CHECK-SAME:           {offsets = [2], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[A]][0, 2] : vector<2xf32> from vector<1x3x2xf32>
+// CHECK: %[[T5:.*]] = vector.insert_strided_slice %[[T4]], %[[T3]]
+// CHECK-SAME:           {offsets = [4], strides = [1]} : vector<2xf32> into vector<6xf32>
+// CHECK: return %[[T5]] : vector<6xf32>
 
 func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
   %s = vector.shape_cast %arg0 : vector<1x3x2xf32> to vector<6xf32>
@@ -104,19 +101,13 @@ func.func @shape_cast_3d1d(%arg0 : vector<1x3x2xf32>) -> vector<6xf32> {
 // CHECK-LABEL: func @shape_cast_1d3d
 // CHECK-SAME: %[[A:.*]]: vector<6xf32>
 // CHECK: %[[C:.*]] = arith.constant dense<0.000000e+00> : vector<2x1x3xf32>
-// CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : f32 from vector<6xf32>
-// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T2:.*]] = vector.extract %[[A]][1] : f32 from vector<6xf32>
-// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [0, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T4:.*]] = vector.extract %[[A]][2] : f32 from vector<6xf32>
-// CHECK: %[[T5:.*]] = vector.insert %[[T4]], %[[T3]] [0, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T6:.*]] = vector.extract %[[A]][3] : f32 from vector<6xf32>
-// CHECK: %[[T7:.*]] = vector.insert %[[T6]], %[[T5]] [1, 0, 0] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T8:.*]] = vector.extract %[[A]][4] : f32 from vector<6xf32>
-// CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T7]] [1, 0, 1] : f32 into vector<2x1x3xf32>
-// CHECK: %[[T10:.*]] = vector.extract %[[A]][5] : f32 from vector<6xf32>
-// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[T9]] [1, 0, 2] : f32 into vector<2x1x3xf32>
-// CHECK: return %[[T11]] : vector<2x1x3xf32>
+// CHECK: %[[T0:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK-SAME:           {offsets = [0], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T1:.*]] = vector.insert %[[T0]], %[[C]] [0, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: %[[T2:.*]] = vector.extract_strided_slice %[[A]]
+// CHECK:                {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32>
+// CHECK: %[[T3:.*]] = vector.insert %[[T2]], %[[T1]] [1, 0] : vector<3xf32> into vector<2x1x3xf32>
+// CHECK: return %[[T3]] : vector<2x1x3xf32>
 
 func.func @shape_cast_1d3d(%arg0 : vector<6xf32>) -> vector<2x1x3xf32> {
   %s = vector.shape_cast %arg0 : vector<6xf32> to vector<2x1x3xf32>

Member

@kuhar left a comment


I cherry-picked this to see if anything breaks in the SPIR-V compilation pipeline in IREE, and it's all good: it doesn't seem to affect the insert_/extract_strided_slice decomposition.

Contributor

@hanhanW left a comment


Overall looks good to me, thanks for the patch!

Comment on lines 71 to 72
SmallVector<int64_t> srcIdx(srcRank - 1);
SmallVector<int64_t> resIdx(resRank);
Contributor


[optional nit]: Interesting code, and it initializes all the values to zero. I know that it works and how it works, but it is not common in the codebase, IMO. We usually pass the default value to the constructor explicitly. Would you mind adding the explicit default value to the constructor?
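A sketch of the suggested change (variable names taken from the diff above; SmallVector's (size, value) constructor makes the zero initialization explicit, though value-initialization already zeroes the elements):

SmallVector<int64_t> srcIdx(srcRank - 1, 0);
SmallVector<int64_t> resIdx(resRank, 0);

The two corresponding declarations in the upcast pattern, quoted in the next comment, would get the same treatment.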

Comment on lines 127 to 128
SmallVector<int64_t> srcIdx(srcRank);
SmallVector<int64_t> resIdx(resRank - 1);
Contributor


[optional] Ditto: if the above suggestion is applied, please also update this code for consistency.

indices[dim] += step;
if (indices[dim] < vecType.getDimSize(dim))
break;

Contributor


Should we add an assertion that indices[dim] == vecType.getDimSize(dim) here? It would look weird to me if the index ever overshot the dimension size. An assertion would be a sanity check in this case.
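A sketch of incIdx with the suggested assertion; it holds under the patch's usage, where `step` always evenly divides the size of the dimension it increments:

/// Increments n-D `indices` by `step` starting from the innermost dimension.
static void incIdx(SmallVectorImpl<int64_t> &indices, VectorType vecType,
                   int step = 1) {
  for (int dim : llvm::reverse(llvm::seq<int>(0, indices.size()))) {
    indices[dim] += step;
    if (indices[dim] < vecType.getDimSize(dim))
      break;

    // Suggested sanity check: on overflow, the index must land exactly on
    // the dimension bound; anything larger means `step` skipped past it.
    assert(indices[dim] == vecType.getDimSize(dim) &&
           "index stepped past the dimension size");
    indices[dim] = 0;
    step = 1;
  }
}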

Contributor

@hanhanW left a comment


LGTM, thanks!

@dcaballe merged commit a7a4c16 into llvm:main Jan 27, 2025
8 checks passed
