@nbpatel
Contributor

@nbpatel nbpatel commented Nov 21, 2025

This PR adds unrolling for vector.create_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as:
min(max(originalMaskSize[d] - offset[d], 0), unrolledMaskSize[d]).
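To make the formula concrete, here is a minimal standalone sketch in plain C++ (no MLIR dependencies). The helper names `localMaskSize` and `tileMaskSizes` are hypothetical and exist only for illustration; they are not part of the patch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Local mask size of one unrolled tile along a single dimension:
//   min(max(originalMaskSize - offset, 0), tileSize)
// (hypothetical helper, mirrors the formula in the PR description).
int64_t localMaskSize(int64_t originalMaskSize, int64_t offset,
                      int64_t tileSize) {
  assert(tileSize > 0 && "tile size must be positive");
  return std::min(std::max<int64_t>(originalMaskSize - offset, 0), tileSize);
}

// Applies the formula to every dimension of one tile located at `offsets`.
std::vector<int64_t> tileMaskSizes(const std::vector<int64_t> &maskSizes,
                                   const std::vector<int64_t> &offsets,
                                   const std::vector<int64_t> &tileShape) {
  assert(maskSizes.size() == offsets.size() &&
         maskSizes.size() == tileShape.size());
  std::vector<int64_t> sizes;
  for (std::size_t d = 0; d < maskSizes.size(); ++d)
    sizes.push_back(localMaskSize(maskSizes[d], offsets[d], tileShape[d]));
  return sizes;
}
```

For a mask of size 6x10 inside `vector<8x16xi1>` unrolled to 4x8 tiles, `tileMaskSizes({6, 10}, {0, 8}, {4, 8})` yields `{4, 2}` and `tileMaskSizes({6, 10}, {4, 8}, {4, 8})` yields `{2, 2}`, matching the example in the pattern's documentation comment.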

@llvmbot
Member

llvmbot commented Nov 21, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

Changes

This PR adds unrolling for vector.create_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as:
min(max(originalMaskSize[d] - offset[d], 0), unrolledMaskSize[d]).


Full diff: https://github.com/llvm/llvm-project/pull/169119.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/IR/VectorOps.td (+3-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp (+93-2)
  • (modified) mlir/test/Dialect/Vector/vector-unroll-options.mlir (+39)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+6)
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ebcaa03a470..d8ed46c2820fe 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2605,7 +2605,9 @@ def Vector_ConstantMaskOp :
 }
 
 def Vector_CreateMaskOp :
-  Vector_Op<"create_mask", [Pure]>,
+  Vector_Op<"create_mask", [Pure, 
+   DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+   ]>,
     Arguments<(ins Variadic<Index>:$operands)>,
     Results<(outs VectorOfAnyRankOf<[I1]>)> {
   let summary = "creates a vector mask";
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b60f80534bfb6..ef239e0b40b04 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,96 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };
 
+/// This pattern unrolls `vector.create_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// its local mask size in each dimension (d) as:
+/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
+/// Example:
+///   Given a create_mask operation:
+///     %0 = vector.create_mask %c6, %c10 : vector<8x16xi1>  // mask first 6x10
+///     elements
+///
+///   and a target unroll shape of <4x8>, the pattern produces:
+///
+///     %false = arith.constant dense<false> : vector<8x16xi1>
+///
+///     Slice [0,0]:
+///     mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
+///     %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
+///     %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///     Slice [0,8]:
+///     mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
+///     %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
+///     %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///     Slice [4,0]:
+///     mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
+///     %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
+///     %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+///     Slice [4,8]:
+///     mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
+///     %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
+///     %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+///       : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
+  UnrollCreateMaskPattern(MLIRContext *context,
+                          const vector::UnrollVectorOptions &options,
+                          PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, createMaskOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType resultType = createMaskOp.getVectorType();
+    SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
+    Location loc = createMaskOp.getLoc();
+
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+    auto targetVectorType = VectorType::get(*targetShape, rewriter.getI1Type());
+    SmallVector<int64_t> strides(targetShape->size(), 1);
+
+    // In each dimension (d), each unrolled vector computes its mask size as:
+    // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(originalSize, *targetShape)) {
+      SmallVector<Value> unrolledOperands;
+
+      for (auto [i, originalMaskOperand] :
+           llvm::enumerate(createMaskOp.getOperands())) {
+        Value offsetVal =
+            arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
+        Value adjustedMaskSize = arith::SubIOp::create(
+            rewriter, loc, originalMaskOperand, offsetVal);
+        Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+        Value unrolledDimSize =
+            arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
+        Value nonNegative =
+            arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+        Value unrolledOperand =
+            arith::MinSIOp::create(rewriter, loc, nonNegative, unrolledDimSize);
+        unrolledOperands.push_back(unrolledOperand);
+      }
+
+      auto unrolledMask = vector::CreateMaskOp::create(
+          rewriter, loc, targetVectorType, unrolledOperands);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, unrolledMask, result, offsets, strides);
+    }
+    rewriter.replaceOp(createMaskOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 /// Checks whether extractShape is a contiguous slice of shape.
 /// For extractShape to be contiguous in shape:
 /// 1) All but the leading dimension of extractShape and shape must match
@@ -1202,8 +1292,9 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
-               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
-      patterns.getContext(), options, benefit);
+               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
+               UnrollCreateMaskPattern>(patterns.getContext(), options,
+                                        benefit);
 }
 
 void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index dec32e1c61a9b..70a34f227802e 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -497,6 +497,45 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
 // CHECK-NOT: arith.addf
 // CHECK: return
 
+func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> {
+  %0 = vector.create_mask %size1, %size2 : vector<16x16xi1>
+  return %0 : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1>
+//       CHECK:   %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+//       CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//       CHECK:   %[[C8:.*]] = arith.constant 8 : index
+//       CHECK:   %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+//       CHECK:   %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index
+//       CHECK:   %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+//       CHECK:   %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index
+//       CHECK:   %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1>
+//       CHECK:   %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+//       CHECK:   %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index
+//       CHECK:   %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+//       CHECK:   %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index
+//       CHECK:   %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index
+//       CHECK:   %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1>
+//       CHECK:   %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+//       CHECK:   %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index
+//       CHECK:   %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index
+//       CHECK:   %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+//       CHECK:   %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index
+//       CHECK:   %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1>
+//       CHECK:   %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+//       CHECK:   %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index
+//       CHECK:   %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index
+//       CHECK:   %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+//       CHECK:   %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index
+//       CHECK:   %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index
+//       CHECK:   %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1>
+//       CHECK:   %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+//       CHECK:   return %[[INS11]] : vector<16x16xi1>
 
 func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
   %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e8ea0cc02d7f6..91e8cf72e64c3 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
                                      .setFilterConstraint([](Operation *op) {
                                        return success(isa<vector::StepOp>(op));
                                      }));
+    populateVectorUnrollPatterns(patterns,
+                                UnrollVectorOptions()
+                                    .setNativeShape(ArrayRef<int64_t>{8, 8})
+                                    .setFilterConstraint([](Operation *op) {
+                                      return success(isa<vector::CreateMaskOp>(op));
+                                    }));
     populateVectorUnrollPatterns(
         patterns,
         UnrollVectorOptions()

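As a sanity check of the tiling logic in the diff above, the following plain C++ sketch rebuilds a 2-D mask tile by tile and lets it be compared against the directly created mask. It is only a model under stated assumptions: masks are nested `std::vector<bool>`, element (i, j) is set iff `i < size0 && j < size1`, the shape is evenly divisible by the tile shape, and the function names `createMask` and `createMaskUnrolled` are hypothetical.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

using Mask = std::vector<std::vector<bool>>;

// Reference model (assumption): element (i, j) is set iff
// i < size0 and j < size1.
Mask createMask(int64_t rows, int64_t cols, int64_t size0, int64_t size1) {
  Mask m(rows, std::vector<bool>(cols, false));
  for (int64_t i = 0; i < rows; ++i)
    for (int64_t j = 0; j < cols; ++j)
      m[i][j] = i < size0 && j < size1;
  return m;
}

// Rebuilds the mask from tiles, visiting tile offsets in row-major order and
// inserting each tile into an all-false result, as the pattern does with
// insert_strided_slice. Assumes rows % tileRows == 0 and cols % tileCols == 0.
Mask createMaskUnrolled(int64_t rows, int64_t cols, int64_t tileRows,
                        int64_t tileCols, int64_t size0, int64_t size1) {
  Mask result(rows, std::vector<bool>(cols, false));
  for (int64_t r = 0; r < rows; r += tileRows) {
    for (int64_t c = 0; c < cols; c += tileCols) {
      // min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d])
      int64_t local0 = std::min(std::max<int64_t>(size0 - r, 0), tileRows);
      int64_t local1 = std::min(std::max<int64_t>(size1 - c, 0), tileCols);
      Mask tile = createMask(tileRows, tileCols, local0, local1);
      for (int64_t i = 0; i < tileRows; ++i)
        for (int64_t j = 0; j < tileCols; ++j)
          result[r + i][c + j] = tile[i][j];
    }
  }
  return result;
}
```

With the test configuration from this PR (a 16x16 mask and a native shape of 8x8), `createMaskUnrolled(16, 16, 8, 8, s0, s1)` should equal `createMask(16, 16, s0, s1)` for every pair of sizes in the range 0 to 16.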
@nbpatel nbpatel requested a review from Jianhui-Li November 21, 2025 22:56
@github-actions

github-actions bot commented Nov 21, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@github-actions

github-actions bot commented Nov 21, 2025

🐧 Linux x64 Test Results

  • 7156 tests passed
  • 594 tests skipped

@kuhar kuhar requested a review from amd-eochoalo November 21, 2025 23:59
Contributor

@amd-eochoalo amd-eochoalo left a comment

Just that small nit. Thanks!

@nbpatel
Contributor Author

nbpatel commented Nov 26, 2025

pinging for review @banach-space @kuhar

      auto unrolledMask = vector::CreateMaskOp::create(
          rewriter, loc, targetVectorType, unrolledOperands);
      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
          loc, unrolledMask, result, offsets, strides);
Contributor

We'd need to use createOrFold for many if not all of the previous operations. For the example that you provided in the documentation, I'd expect everything to fold into a single constant mask op. We should make sure that such a folding happens in the tests.

Contributor Author

Thanks for pointing that out. Changed it.
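The folding concern raised in this thread can be modeled outside MLIR. The sketch below is an illustration only, not the patch's code: a mask operand is represented as `std::optional<int64_t>` (a value when it is a compile-time constant, `std::nullopt` when it is a runtime value), and the hypothetical helper `foldLocalMaskSize` shows that the subtract/max/min chain for one dimension collapses to a constant whenever the operand is constant.

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>

// A mask operand that is either a known constant or an unknown runtime value.
using MaybeConst = std::optional<int64_t>;

// Models constant folding for one dimension of one tile (hypothetical helper):
// with a constant operand the whole chain
//   min(max(originalMaskSize - offset, 0), tileSize)
// evaluates at compile time; otherwise nothing can be folded here.
MaybeConst foldLocalMaskSize(MaybeConst originalMaskSize, int64_t offset,
                             int64_t tileSize) {
  if (!originalMaskSize)
    return std::nullopt;
  return std::min(std::max<int64_t>(*originalMaskSize - offset, 0), tileSize);
}
```

For example, `foldLocalMaskSize(6, 4, 4)` holds 2, while `foldLocalMaskSize(std::nullopt, 4, 4)` stays empty. Even with runtime operands some simplification is possible: in the CHECK lines of the test above, tiles at offset 0 have no `arith.subi`, which suggests the subtraction of zero is folded away.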

// CHECK: return

func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> {
  %0 = vector.create_mask %size1, %size2 : vector<16x16xi1>
Contributor

Please add a test with constant indices (fully or partially constant?) to make sure that folding happens.

@nbpatel
Contributor Author

nbpatel commented Dec 2, 2025

@dcaballe any further comments?

@nbpatel nbpatel requested a review from dcaballe December 3, 2025 03:48
Contributor

@dcaballe dcaballe left a comment

LG, thanks!

@nbpatel nbpatel merged commit 7931e2f into llvm:main Dec 3, 2025
10 checks passed
kcloudy0717 pushed a commit to kcloudy0717/llvm-project that referenced this pull request Dec 4, 2025
This PR adds unrolling for vector.create_mask op based on the
targetShape. Each unrolled vector computes its local mask size in each
dimension (d) as:
min(max(originalMaskSize[d] - offset[d], 0), unrolledMaskSize[d]).