Conversation

nbpatel (Contributor) commented Oct 8, 2025

This PR enhances the elementwise unrolling pattern to support unrolling a higher-rank vector with a lower-rank target shape. The approach is to add leading unit dims to the lower-rank targetShape so that it matches the rank of the original vector (ExtractStridedSlice requires the source and result ranks to match), extract the slice, shape_cast it down to targetShape's rank, perform the operation, and shape_cast the result back before inserting it into the full-rank result.
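
For example, with a 2x2 target shape applied to an elementwise op on vector<2x2x2xf32>, each unrolled iteration now looks roughly like the sketch below (operand names such as %src and %acc are illustrative only; the exact IR is covered by the updated tests in this PR):

    // Extract a rank-3 slice whose leading dim is a unit dim.
    %e = vector.extract_strided_slice %src {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
    // Drop the leading unit dim so the elementwise op runs at the target rank.
    %s = vector.shape_cast %e : vector<1x2x2xf32> to vector<2x2xf32>
    %a = arith.addf %s, %s : vector<2x2xf32>
    // Restore the unit dim before inserting back into the full-rank result.
    %c = vector.shape_cast %a : vector<2x2xf32> to vector<1x2x2xf32>
    %r = vector.insert_strided_slice %c, %acc {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>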

llvmbot (Member) commented Oct 8, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/162515.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+38-14)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+54-6)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5f1cdd3..62d65e28e8c2e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -468,23 +468,30 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    // Bail-out if rank(source) != rank(target). The main limitation here is the
-    // fact that `ExtractStridedSlice` requires the rank for the input and
-    // output to match. If needed, we can relax this later.
-    if (originalSize.size() != targetShape->size())
-      return rewriter.notifyMatchFailure(
-          op, "expected input vector rank to match target shape rank");
+
     Location loc = op->getLoc();
+
+    // Handle rank mismatch by adding leading unit dimensions to targetShape
+    SmallVector<int64_t> adjustedTargetShape = *targetShape;
+    SmallVector<int64_t> adjustedOffsets;
+    if (originalSize.size() > targetShape->size()) {
+      // Add leading unit dimensions to targetShape
+      int64_t rankDiff = originalSize.size() - targetShape->size();
+      adjustedTargetShape.insert(adjustedTargetShape.begin(), rankDiff, 1);
+    }
+
     // Prepare the result vector.
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
-    SmallVector<int64_t> strides(targetShape->size(), 1);
-    VectorType newVecType =
+    SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
+    VectorType extractVecType =
+        VectorType::get(adjustedTargetShape, dstVecType.getElementType());
+    VectorType computeVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
     // Create the unrolled computation.
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(originalSize, *targetShape)) {
+         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
       SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
         auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
           extractOperands.push_back(operand.get());
           continue;
         }
-        extractOperands.push_back(
-            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-                loc, operand.get(), offsets, *targetShape, strides));
+        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+        // Reshape to remove leading unit dims if needed
+        if (adjustedTargetShape.size() > targetShape->size()) {
+          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+              loc, VectorType::get(*targetShape, vecType.getElementType()),
+              extracted);
+        }
+        extractOperands.push_back(extracted);
       }
+
       Operation *newOp = cloneOpWithOperandsAndTypes(
-          rewriter, loc, op, extractOperands, newVecType);
+          rewriter, loc, op, extractOperands, computeVecType);
+
+      Value computeResult = newOp->getResult(0);
+
+      // Reshape back to higher rank if needed for insertion
+      if (adjustedTargetShape.size() > targetShape->size()) {
+        computeResult = rewriter.createOrFold<vector::ShapeCastOp>(
+            loc, extractVecType, computeResult);
+      }
+
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, newOp->getResult(0), result, offsets, strides);
+          loc, computeResult, result, offsets, strides);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 35db14e0f7f1d..a26e4b0baa05b 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,15 +188,40 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
 
-// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
-func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
   %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
   return %0 : vector<3x2x2xf32>
 }
-// CHECK-LABEL: func @negative_vector_fma_3d
-//   CHECK-NOT: vector.extract_strided_slice
-//       CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-//       CHECK: return
+// CHECK-LABEL: func @vector_fma_3d
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
+//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[FMA0]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
+//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[FMA1]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
+//       CHECK:   %[[SC2:.*]] = vector.shape_cast %[[FMA2]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[SC2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   return %[[I2]] : vector<3x2x2xf32>
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -440,3 +465,26 @@ func.func @vector_step() -> vector<32xindex> {
 // CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
 // CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
 // CHECK: return %[[INS3]] : vector<32xindex>
+
+
+func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+  %0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
+  return %0 : vector<2x2x2xf32>
+}
+// CHECK-LABEL: func @elementwise
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
+//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[ADD0]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
+//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[ADD1]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   return %[[I1]] : vector<2x2x2xf32>

nbpatel (Contributor, Author) commented Oct 8, 2025

@newling please take a look

@nbpatel nbpatel closed this Oct 8, 2025
@nbpatel nbpatel reopened this Oct 8, 2025