Skip to content

Conversation

@Groverkss
Copy link
Member

@Groverkss Groverkss commented Jul 2, 2025

extract_strided_slice(create_mask) can be folded into create_mask by simply subtracting the offsets from the bounds.

@llvmbot
Copy link
Member

llvmbot commented Jul 2, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Kunwar Grover (Groverkss)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/146745.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+59-3)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+17)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 1fb8c7a928e06..66c2fe50529d7 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -4081,6 +4081,62 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
 
 namespace {
 
+class StridedSliceCreateMaskFolder final
+    : public OpRewritePattern<ExtractStridedSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+public:
+  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
+                                PatternRewriter &rewriter) const override {
+    Location loc = extractStridedSliceOp.getLoc();
+    // Return if 'extractStridedSliceOp' operand is not defined by a
+    // CreateMaskOp.
+    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
+    auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
+    if (!createMaskOp)
+      return failure();
+    // Return if 'extractStridedSliceOp' has non-unit strides.
+    if (extractStridedSliceOp.hasNonUnitStrides())
+      return failure();
+    // Gather constant mask dimension sizes.
+    SmallVector<Value> maskDimSizes(createMaskOp.getOperands());
+    // Gather strided slice offsets and sizes.
+    SmallVector<int64_t, 4> sliceOffsets;
+    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
+                               sliceOffsets);
+    SmallVector<int64_t, 4> sliceSizes;
+    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
+
+    // Compute slice of vector mask region.
+    SmallVector<Value> sliceMaskDimSizes;
+    sliceMaskDimSizes.reserve(maskDimSizes.size());
+    for (auto [maskDimSize, sliceOffset, sliceSize] :
+         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
+      // No need to clamp on min/max values, because create_mask has clamping
+      // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
+      // greater than the vector dim size.
+      IntegerAttr offsetAttr =
+          rewriter.getIntegerAttr(maskDimSize.getType(), sliceOffset);
+      Value offset = rewriter.create<arith::ConstantOp>(loc, offsetAttr);
+      Value sliceMaskDimSize =
+          rewriter.create<arith::SubIOp>(loc, maskDimSize, offset);
+      sliceMaskDimSizes.push_back(sliceMaskDimSize);
+    }
+    // Add unchanged dimensions.
+    if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
+      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i) {
+        sliceMaskDimSizes.push_back(maskDimSizes[i]);
+      }
+    }
+    // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
+    // region.
+    rewriter.replaceOpWithNewOp<CreateMaskOp>(
+        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
+        sliceMaskDimSizes);
+    return success();
+  }
+};
+
 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
 // ConstantMaskOp.
 class StridedSliceConstantMaskFolder final
@@ -4279,9 +4335,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {
   // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
   // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
-  results.add<StridedSliceConstantMaskFolder, StridedSliceBroadcast,
-              StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
-      context);
+  results.add<StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder,
+              StridedSliceBroadcast, StridedSliceSplat,
+              ContiguousExtractStridedSliceToExtract>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 0282e9cac5e02..c09dab8232900 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -361,6 +361,23 @@ func.func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
 
 // -----
 
+// CHECK-LABEL: func.func @extract_strided_slice_of_create_mask
+// CHECK-SAME: (%[[DIM0:.+]]: index, %[[DIM1:.+]]: index)
+func.func @extract_strided_slice_of_create_mask(%dim0: index, %dim1: index) -> (vector<2x2xi1>) {
+  %0 = vector.create_mask %dim0, %dim1 : vector<4x3xi1>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [2, 1], sizes = [2, 2], strides = [1, 1]}
+      : vector<4x3xi1> to vector<2x2xi1>
+  // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+  // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+  // CHECK-DAG: %[[A:.+]] = arith.subi %[[DIM0]], %[[C2]]
+  // CHECK-DAG: %[[B:.+]] = arith.subi %[[DIM1]], %[[C1]]
+  // CHECK: vector.create_mask %[[A]], %[[B]] : vector<2x2xi1>
+  return %1 : vector<2x2xi1>
+}
+
+// -----
+
 // CHECK-LABEL: extract_strided_fold
 //  CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
 //  CHECK-NEXT:   return %[[ARG]] : vector<4x3xi1>

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, thanks! Just some minor comments.

I'm curious: what is the current lowering of this without the canonicalization? A vector.shuffle extracting the slice from the mask?

@Groverkss Groverkss requested review from dcaballe and joker-eph July 4, 2025 10:53
@Groverkss
Copy link
Member Author

LG, thanks! Just some minor comments.

I'm curious: what is the current lowering of this without the canonicalization? A vector.shuffle extracting the slice from the mask?

Yes. Not only that, in a lot of cases you could actually fold the mask with the maskedload into a load, but the extract_strided_slice was blocking it.

@Groverkss Groverkss requested a review from kuhar July 4, 2025 13:48
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like a good idea to me

@Groverkss Groverkss requested a review from kuhar July 9, 2025 13:08
@Groverkss Groverkss merged commit 0227aef into llvm:main Jul 10, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants