From 86425375b67f2e105af80e31d8e96b87fe22ad82 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 14 Nov 2025 21:09:50 +0000 Subject: [PATCH 1/3] Add unroll pattern for vector.create_mask --- .../mlir/Dialect/Vector/IR/VectorOps.td | 4 +- .../Vector/Transforms/VectorUnroll.cpp | 94 ++++++++++++++++++- .../Dialect/Vector/vector-unroll-options.mlir | 40 ++++++++ .../Dialect/Vector/TestVectorTransforms.cpp | 6 ++ 4 files changed, 141 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 43172ff2082df..4f9252f046eab 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2607,7 +2607,9 @@ def Vector_ConstantMaskOp : } def Vector_CreateMaskOp : - Vector_Op<"create_mask", [Pure]>, + Vector_Op<"create_mask", [Pure, + DeclareOpInterfaceMethods + ]>, Arguments<(ins Variadic:$operands)>, Results<(outs VectorOfAnyRankOf<[I1]>)> { let summary = "creates a vector mask"; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fbae0989bed26..ca2978c5d5a19 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1003,6 +1003,96 @@ struct UnrollFromElements : OpRewritePattern { vector::UnrollVectorOptions options; }; +/// This pattern unrolls `vector.create_mask` operations into smaller mask +/// operations based on the target unroll shape. Each unrolled slice computes +/// its local mask size in each dimension (d) as: +/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]). +/// Example: +/// Given a create_mask operation: +/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10 +/// elements +/// +/// and a target unroll shape of <4x8>, the pattern produces: +/// +/// %false = arith.constant dense : vector<8x16xi1> +/// +/// Slice [0,0]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8 +/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1> +/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [0,8]: +/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2 +/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1> +/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,0]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8 +/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1> +/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +/// Slice [4,8]: +/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2 +/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1> +/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1] +/// : vector<4x8xi1> into vector<8x16xi1> +struct UnrollCreateMaskPattern : public OpRewritePattern { + UnrollCreateMaskPattern(MLIRContext *context, + const vector::UnrollVectorOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + options(options) {} + + LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp, + PatternRewriter &rewriter) const override { + auto targetShape = getTargetShape(options, createMaskOp); + if (!targetShape) + return failure(); + + VectorType resultType = createMaskOp.getVectorType(); + SmallVector originalSize = *createMaskOp.getShapeForUnroll(); + Location loc = createMaskOp.getLoc(); + + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); + auto targetVectorType = VectorType::get(*targetShape, rewriter.getI1Type()); + SmallVector strides(targetShape->size(), 1); + + // In each dimension (d), each unrolled vector computes its mask size as: + // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]). + for (SmallVector offsets : + StaticTileOffsetRange(originalSize, *targetShape)) { + SmallVector unrolledOperands; + + for (auto [i, originalMaskOperand] : + llvm::enumerate(createMaskOp.getOperands())) { + Value offsetVal = + arith::ConstantIndexOp::create(rewriter, loc, offsets[i]); + Value adjustedMaskSize = arith::SubIOp::create( + rewriter, loc, originalMaskOperand, offsetVal); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value unrolledDimSize = + arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]); + Value nonNegative = + arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); + Value unrolledOperand = + arith::MinSIOp::create(rewriter, loc, nonNegative, unrolledDimSize); + unrolledOperands.push_back(unrolledOperand); + } + + auto unrolledMask = vector::CreateMaskOp::create( + rewriter, loc, targetVectorType, unrolledOperands); + result = rewriter.createOrFold( + loc, unrolledMask, result, offsets, strides); + } + rewriter.replaceOp(createMaskOp, result); + return success(); + } + +private: + vector::UnrollVectorOptions options; +}; + } // namespace void mlir::vector::populateVectorUnrollPatterns( @@ -1013,8 +1103,8 @@ void mlir::vector::populateVectorUnrollPatterns( UnrollReductionPattern, UnrollMultiReductionPattern, UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, - UnrollToElements, UnrollStepPattern>(patterns.getContext(), - options, benefit); + UnrollToElements, UnrollStepPattern, UnrollCreateMaskPattern>( + patterns.getContext(), options, benefit); } void mlir::vector::populateVectorToElementsUnrollPatterns( diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index e5a98b5c67f33..f36c77ee8799f 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -496,3 +496,43 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 // CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32> // CHECK-NOT: arith.addf // CHECK: return + +func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> { + %0 = vector.create_mask %size1, %size2 : vector<16x16xi1> + return %0 : vector<16x16xi1> +} + +// CHECK-LABEL: func @vector_create_mask +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1> +// CHECK: %[[CST:.*]] = arith.constant dense : vector<16x16xi1> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C8:.*]] = arith.constant 8 : index +// CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index +// CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index +// CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index +// CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index +// CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1> +// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index +// CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index +// CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index +// CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index +// CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index +// CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1> +// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index +// CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index +// CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index +// CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index +// CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index +// CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1> +// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index +// CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index +// CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index +// CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index +// CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index +// CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index +// CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1> +// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: return %[[INS11]] : vector<16x16xi1> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 79bfc9bbcda71..8e69a2ab37e5e 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success(isa(op)); })); + populateVectorUnrollPatterns(patterns, + UnrollVectorOptions() + .setNativeShape(ArrayRef{8, 8}) + .setFilterConstraint([](Operation *op) { + return success(isa(op)); + })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() .setNativeShape(ArrayRef{1, 3, 4, 2}) From f2af4230f4baa1ae541520dcb1168fc35382e3a8 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Fri, 21 Nov 2025 23:04:55 +0000 Subject: [PATCH 2/3] Fix format --- .../test/lib/Dialect/Vector/TestVectorTransforms.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 91e8cf72e64c3..f834d0cdd42bd 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -178,12 +178,12 @@ struct TestVectorUnrollingPatterns .setFilterConstraint([](Operation *op) { return success(isa(op)); })); - populateVectorUnrollPatterns(patterns, - UnrollVectorOptions() - .setNativeShape(ArrayRef{8, 8}) - .setFilterConstraint([](Operation *op) { - return success(isa(op)); - })); + populateVectorUnrollPatterns( + patterns, UnrollVectorOptions() + .setNativeShape(ArrayRef{8, 8}) + .setFilterConstraint([](Operation *op) { + return success(isa(op)); + })); populateVectorUnrollPatterns( patterns, UnrollVectorOptions() From 42522d1e9886ec564a4483810a8da82cd9573901 Mon Sep 17 00:00:00 2001 From: nbpatel Date: Thu, 27 Nov 2025 00:48:07 +0000 Subject: [PATCH 3/3] Address Feedback --- .../Dialect/Vector/Transforms/VectorUnroll.cpp | 17 +++++++++-------- .../Dialect/Vector/vector-unroll-options.mlir | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index ef239e0b40b04..462bd8c3dc4a6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -1055,7 +1055,8 @@ struct UnrollCreateMaskPattern : public OpRewritePattern { Value result = arith::ConstantOp::create(rewriter, loc, resultType, rewriter.getZeroAttr(resultType)); - auto targetVectorType = VectorType::get(*targetShape, rewriter.getI1Type()); + VectorType targetVectorType = + VectorType::get(*targetShape, rewriter.getI1Type()); SmallVector strides(targetShape->size(), 1); // In each dimension (d), each unrolled vector computes its mask size as: @@ -1068,20 +1069,20 @@ struct UnrollCreateMaskPattern : public OpRewritePattern { llvm::enumerate(createMaskOp.getOperands())) { Value offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offsets[i]); - Value adjustedMaskSize = arith::SubIOp::create( - rewriter, loc, originalMaskOperand, offsetVal); + Value adjustedMaskSize = rewriter.createOrFold( + loc, originalMaskOperand, offsetVal); Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); Value unrolledDimSize = arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]); Value nonNegative = - arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); - Value unrolledOperand = - arith::MinSIOp::create(rewriter, loc, nonNegative, unrolledDimSize); + rewriter.createOrFold(loc, adjustedMaskSize, zero); + Value unrolledOperand = rewriter.createOrFold( + loc, nonNegative, unrolledDimSize); unrolledOperands.push_back(unrolledOperand); } - auto unrolledMask = vector::CreateMaskOp::create( - rewriter, loc, targetVectorType, unrolledOperands); + auto unrolledMask = rewriter.createOrFold( + loc, targetVectorType, unrolledOperands); result = rewriter.createOrFold( loc, unrolledMask, result, offsets, strides); } diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir index 70a34f227802e..805e66f133c59 100644 --- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir +++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir @@ -537,6 +537,22 @@ func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> // CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> // CHECK: return %[[INS11]] : vector<16x16xi1> +func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> { + %cst16 = arith.constant 16 : index + %0 = vector.create_mask %cst16, %cst16 : vector<16x16xi1> + return %0 : vector<16x16xi1> +} + +// CHECK-LABEL: func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> { +// CHECK: %[[CST:.*]] = arith.constant dense : vector<16x16xi1> +// CHECK: %[[CST_0:.*]] = arith.constant dense : vector<8x8xi1> +// CHECK: %[[S0:.*]] = vector.insert_strided_slice %[[CST_0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S0]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S2:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S1]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> +// CHECK: return %[[S3]] : vector<16x16xi1> + + func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> { %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32> return %0 : vector<2x2x4xf32>