Skip to content

Conversation

@nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Oct 8, 2025

This PR enhances the elementwise unrolling pattern to support higher rank to lower rank unroll. The approach is to add leading unit dims to lower rank targetShape to match the rank of original vector (because ExtractStridedSlice requires same rank to extractSlices), extract slice, reshape to targetShape's rank and perform the operation.

@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/162515.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+38-14)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+54-6)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index 14639c5f1cdd3..62d65e28e8c2e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -468,23 +468,30 @@ struct UnrollElementwisePattern : public RewritePattern {
     auto dstVecType = cast<VectorType>(op->getResult(0).getType());
     SmallVector<int64_t> originalSize =
         *cast<VectorUnrollOpInterface>(op).getShapeForUnroll();
-    // Bail-out if rank(source) != rank(target). The main limitation here is the
-    // fact that `ExtractStridedSlice` requires the rank for the input and
-    // output to match. If needed, we can relax this later.
-    if (originalSize.size() != targetShape->size())
-      return rewriter.notifyMatchFailure(
-          op, "expected input vector rank to match target shape rank");
+
     Location loc = op->getLoc();
+
+    // Handle rank mismatch by adding leading unit dimensions to targetShape
+    SmallVector<int64_t> adjustedTargetShape = *targetShape;
+    SmallVector<int64_t> adjustedOffsets;
+    if (originalSize.size() > targetShape->size()) {
+      // Add leading unit dimensions to targetShape
+      int64_t rankDiff = originalSize.size() - targetShape->size();
+      adjustedTargetShape.insert(adjustedTargetShape.begin(), rankDiff, 1);
+    }
+
     // Prepare the result vector.
     Value result = arith::ConstantOp::create(rewriter, loc, dstVecType,
                                              rewriter.getZeroAttr(dstVecType));
-    SmallVector<int64_t> strides(targetShape->size(), 1);
-    VectorType newVecType =
+    SmallVector<int64_t> strides(adjustedTargetShape.size(), 1);
+    VectorType extractVecType =
+        VectorType::get(adjustedTargetShape, dstVecType.getElementType());
+    VectorType computeVecType =
         VectorType::get(*targetShape, dstVecType.getElementType());
 
     // Create the unrolled computation.
     for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(originalSize, *targetShape)) {
+         StaticTileOffsetRange(originalSize, adjustedTargetShape)) {
       SmallVector<Value> extractOperands;
       for (OpOperand &operand : op->getOpOperands()) {
         auto vecType = dyn_cast<VectorType>(operand.get().getType());
@@ -492,14 +499,31 @@ struct UnrollElementwisePattern : public RewritePattern {
           extractOperands.push_back(operand.get());
           continue;
         }
-        extractOperands.push_back(
-            rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-                loc, operand.get(), offsets, *targetShape, strides));
+        Value extracted = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand.get(), offsets, adjustedTargetShape, strides);
+
+        // Reshape to remove leading unit dims if needed
+        if (adjustedTargetShape.size() > targetShape->size()) {
+          extracted = rewriter.createOrFold<vector::ShapeCastOp>(
+              loc, VectorType::get(*targetShape, vecType.getElementType()),
+              extracted);
+        }
+        extractOperands.push_back(extracted);
       }
+
       Operation *newOp = cloneOpWithOperandsAndTypes(
-          rewriter, loc, op, extractOperands, newVecType);
+          rewriter, loc, op, extractOperands, computeVecType);
+
+      Value computeResult = newOp->getResult(0);
+
+      // Reshape back to higher rank if needed for insertion
+      if (adjustedTargetShape.size() > targetShape->size()) {
+        computeResult = rewriter.createOrFold<vector::ShapeCastOp>(
+            loc, extractVecType, computeResult);
+      }
+
       result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, newOp->getResult(0), result, offsets, strides);
+          loc, computeResult, result, offsets, strides);
     }
     rewriter.replaceOp(op, result);
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index 35db14e0f7f1d..a26e4b0baa05b 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -188,15 +188,40 @@ func.func @vector_fma(%a: vector<4x4xf32>, %b: vector<4x4xf32>, %c: vector<4x4xf
 //   CHECK-LABEL: func @vector_fma
 // CHECK-COUNT-4: vector.fma %{{.+}}, %{{.+}}, %{{.+}} : vector<2x2xf32>
 
-// TODO: We should be able to unroll this like the example above - this will require extending UnrollElementwisePattern.
-func.func @negative_vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
+func.func @vector_fma_3d(%a: vector<3x2x2xf32>) -> vector<3x2x2xf32>{
   %0 = vector.fma %a, %a, %a : vector<3x2x2xf32>
   return %0 : vector<3x2x2xf32>
 }
-// CHECK-LABEL: func @negative_vector_fma_3d
-//   CHECK-NOT: vector.extract_strided_slice
-//       CHECK: %[[R0:.*]] = vector.fma %{{.+}} : vector<3x2x2xf32>
-//       CHECK: return
+// CHECK-LABEL: func @vector_fma_3d
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<3x2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA0:.*]] = vector.fma %[[S0]], %[[S1]], %[[S2]] : vector<2x2xf32>
+//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[FMA0]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E4:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S4:.*]] = vector.shape_cast %[[E4]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E5:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S5:.*]] = vector.shape_cast %[[E5]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA1:.*]] = vector.fma %[[S3]], %[[S4]], %[[S5]] : vector<2x2xf32>
+//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[FMA1]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   %[[E6:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S6:.*]] = vector.shape_cast %[[E6]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E7:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S7:.*]] = vector.shape_cast %[[E7]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E8:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [2, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<3x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S8:.*]] = vector.shape_cast %[[E8]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[FMA2:.*]] = vector.fma %[[S6]], %[[S7]], %[[S8]] : vector<2x2xf32>
+//       CHECK:   %[[SC2:.*]] = vector.shape_cast %[[FMA2]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I2:.*]] = vector.insert_strided_slice %[[SC2]], %[[I1]] {offsets = [2, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<3x2x2xf32>
+//       CHECK:   return %[[I2]] : vector<3x2x2xf32>
 
 func.func @vector_multi_reduction(%v : vector<4x6xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
   %0 = vector.multi_reduction #vector.kind<add>, %v, %acc [1] : vector<4x6xf32> to vector<4xf32>
@@ -440,3 +465,26 @@ func.func @vector_step() -> vector<32xindex> {
 // CHECK: %[[ADD3:.*]] = arith.addi %[[STEP]], %[[CST]] : vector<8xindex>
 // CHECK: %[[INS3:.*]] = vector.insert_strided_slice %[[ADD3]], %[[INS2]] {offsets = [24], strides = [1]} : vector<8xindex> into vector<32xindex>
 // CHECK: return %[[INS3]] : vector<32xindex>
+
+
+func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
+  %0 = arith.addf %v1, %v2 : vector<2x2x2xf32>
+  return %0 : vector<2x2x2xf32>
+}
+// CHECK-LABEL: func @elementwise
+//       CHECK:   %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x2xf32>
+//       CHECK:   %[[E0:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S0:.*]] = vector.shape_cast %[[E0]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S1:.*]] = vector.shape_cast %[[E1]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD0:.*]] = arith.addf %[[S0]], %[[S1]] : vector<2x2xf32>
+//       CHECK:   %[[SC0:.*]] = vector.shape_cast %[[ADD0]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   %[[E2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S2:.*]] = vector.shape_cast %[[E2]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[E3:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [1, 0, 0], sizes = [1, 2, 2], strides = [1, 1, 1]} : vector<2x2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[S3:.*]] = vector.shape_cast %[[E3]] : vector<1x2x2xf32> to vector<2x2xf32>
+//       CHECK:   %[[ADD1:.*]] = arith.addf %[[S2]], %[[S3]] : vector<2x2xf32>
+//       CHECK:   %[[SC1:.*]] = vector.shape_cast %[[ADD1]] : vector<2x2xf32> to vector<1x2x2xf32>
+//       CHECK:   %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1, 1]} : vector<1x2x2xf32> into vector<2x2x2xf32>
+//       CHECK:   return %[[I1]] : vector<2x2x2xf32>

@nbpatel
Copy link
Contributor Author

nbpatel commented Oct 8, 2025

@newling please take a look

@nbpatel nbpatel closed this Oct 8, 2025
@nbpatel nbpatel reopened this Oct 8, 2025
@banach-space banach-space requested a review from newling October 13, 2025 15:58
@nbpatel
Copy link
Contributor Author

nbpatel commented Oct 13, 2025

Ping for review :)

@nbpatel
Copy link
Contributor Author

nbpatel commented Oct 14, 2025

@newling @dcaballe Addressed comments. Thanks

Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

// CHECK: return %[[INS3]] : vector<32xindex>


func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit[ For consistency.

Suggested change
func.func @elementwise(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
func.func @elementwise_3d_to_2d(%v1: vector<2x2x2xf32>, %v2: vector<2x2x2xf32>) -> vector<2x2x2xf32> {

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % nit, thanks!

@nbpatel nbpatel merged commit 4ff8f11 into llvm:main Oct 15, 2025
10 checks passed
@nbpatel nbpatel deleted the xegpu-elementwise branch October 16, 2025 18:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants