Skip to content

Commit 60d78ed

Browse files
nbpatel authored and kcloudy0717 committed
[MLIR][Vector] Add unroll pattern for vector.create_mask (llvm#169119)
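This PR adds unrolling for the vector.create_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as: min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]), where offset[d] is the unrolled slice's offset into the original vector and unrolledDimSize[d] is the target (unrolled) shape in that dimension.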
1 parent e8abadc commit 60d78ed

File tree

4 files changed

+158
-3
lines changed

4 files changed

+158
-3
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2605,7 +2605,9 @@ def Vector_ConstantMaskOp :
26052605
}
26062606

26072607
def Vector_CreateMaskOp :
2608-
Vector_Op<"create_mask", [Pure]>,
2608+
Vector_Op<"create_mask", [Pure,
2609+
DeclareOpInterfaceMethods<VectorUnrollOpInterface>
2610+
]>,
26092611
Arguments<(ins Variadic<Index>:$operands)>,
26102612
Results<(outs VectorOfAnyRankOf<[I1]>)> {
26112613
let summary = "creates a vector mask";

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 94 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,97 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
10031003
vector::UnrollVectorOptions options;
10041004
};
10051005

1006+
/// This pattern unrolls `vector.create_mask` operations into smaller mask
1007+
/// operations based on the target unroll shape. Each unrolled slice computes
1008+
/// its local mask size in each dimension (d) as:
1009+
/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
1010+
/// Example:
1011+
/// Given a create_mask operation:
1012+
/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
1013+
/// elements
1014+
///
1015+
/// and a target unroll shape of <4x8>, the pattern produces:
1016+
///
1017+
/// %false = arith.constant dense<false> : vector<8x16xi1>
1018+
///
1019+
/// Slice [0,0]:
1020+
/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
1021+
/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
1022+
/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
1023+
/// : vector<4x8xi1> into vector<8x16xi1>
1024+
/// Slice [0,8]:
1025+
/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
1026+
/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
1027+
/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
1028+
/// : vector<4x8xi1> into vector<8x16xi1>
1029+
/// Slice [4,0]:
1030+
/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
1031+
/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
1032+
/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
1033+
/// : vector<4x8xi1> into vector<8x16xi1>
1034+
/// Slice [4,8]:
1035+
/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
1036+
/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
1037+
/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
1038+
/// : vector<4x8xi1> into vector<8x16xi1>
1039+
struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
1040+
UnrollCreateMaskPattern(MLIRContext *context,
1041+
const vector::UnrollVectorOptions &options,
1042+
PatternBenefit benefit = 1)
1043+
: OpRewritePattern<vector::CreateMaskOp>(context, benefit),
1044+
options(options) {}
1045+
1046+
LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
1047+
PatternRewriter &rewriter) const override {
1048+
auto targetShape = getTargetShape(options, createMaskOp);
1049+
if (!targetShape)
1050+
return failure();
1051+
1052+
VectorType resultType = createMaskOp.getVectorType();
1053+
SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
1054+
Location loc = createMaskOp.getLoc();
1055+
1056+
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
1057+
rewriter.getZeroAttr(resultType));
1058+
VectorType targetVectorType =
1059+
VectorType::get(*targetShape, rewriter.getI1Type());
1060+
SmallVector<int64_t> strides(targetShape->size(), 1);
1061+
1062+
// In each dimension (d), each unrolled vector computes its mask size as:
1063+
// min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
1064+
for (SmallVector<int64_t> offsets :
1065+
StaticTileOffsetRange(originalSize, *targetShape)) {
1066+
SmallVector<Value> unrolledOperands;
1067+
1068+
for (auto [i, originalMaskOperand] :
1069+
llvm::enumerate(createMaskOp.getOperands())) {
1070+
Value offsetVal =
1071+
arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
1072+
Value adjustedMaskSize = rewriter.createOrFold<arith::SubIOp>(
1073+
loc, originalMaskOperand, offsetVal);
1074+
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1075+
Value unrolledDimSize =
1076+
arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
1077+
Value nonNegative =
1078+
rewriter.createOrFold<arith::MaxSIOp>(loc, adjustedMaskSize, zero);
1079+
Value unrolledOperand = rewriter.createOrFold<arith::MinSIOp>(
1080+
loc, nonNegative, unrolledDimSize);
1081+
unrolledOperands.push_back(unrolledOperand);
1082+
}
1083+
1084+
auto unrolledMask = rewriter.createOrFold<vector::CreateMaskOp>(
1085+
loc, targetVectorType, unrolledOperands);
1086+
result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
1087+
loc, unrolledMask, result, offsets, strides);
1088+
}
1089+
rewriter.replaceOp(createMaskOp, result);
1090+
return success();
1091+
}
1092+
1093+
private:
1094+
vector::UnrollVectorOptions options;
1095+
};
1096+
10061097
/// Checks whether extractShape is a contiguous slice of shape.
10071098
/// For extractShape to be contiguous in shape:
10081099
/// 1) All but the leading dimension of extractShape and shape must match
@@ -1202,8 +1293,9 @@ void mlir::vector::populateVectorUnrollPatterns(
12021293
UnrollReductionPattern, UnrollMultiReductionPattern,
12031294
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
12041295
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1205-
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
1206-
patterns.getContext(), options, benefit);
1296+
UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
1297+
UnrollCreateMaskPattern>(patterns.getContext(), options,
1298+
benefit);
12071299
}
12081300

12091301
void mlir::vector::populateVectorToElementsUnrollPatterns(

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,61 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
497497
// CHECK-NOT: arith.addf
498498
// CHECK: return
499499

500+
func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> {
501+
%0 = vector.create_mask %size1, %size2 : vector<16x16xi1>
502+
return %0 : vector<16x16xi1>
503+
}
504+
505+
// CHECK-LABEL: func @vector_create_mask
506+
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1>
507+
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
508+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
509+
// CHECK: %[[C8:.*]] = arith.constant 8 : index
510+
// CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
511+
// CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index
512+
// CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
513+
// CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index
514+
// CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1>
515+
// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
516+
// CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
517+
// CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index
518+
// CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
519+
// CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index
520+
// CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index
521+
// CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1>
522+
// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
523+
// CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
524+
// CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index
525+
// CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index
526+
// CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
527+
// CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index
528+
// CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1>
529+
// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
530+
// CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
531+
// CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index
532+
// CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index
533+
// CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
534+
// CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index
535+
// CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index
536+
// CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1>
537+
// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
538+
// CHECK: return %[[INS11]] : vector<16x16xi1>
539+
540+
func.func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
541+
%cst16 = arith.constant 16 : index
542+
%0 = vector.create_mask %cst16, %cst16 : vector<16x16xi1>
543+
return %0 : vector<16x16xi1>
544+
}
545+
546+
// CHECK-LABEL: func @vector_create_mask_constant_dim_sizes() -> vector<16x16xi1> {
547+
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
548+
// CHECK: %[[CST_0:.*]] = arith.constant dense<true> : vector<8x8xi1>
549+
// CHECK: %[[S0:.*]] = vector.insert_strided_slice %[[CST_0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
550+
// CHECK: %[[S1:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S0]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
551+
// CHECK: %[[S2:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S1]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
552+
// CHECK: %[[S3:.*]] = vector.insert_strided_slice %[[CST_0]], %[[S2]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
553+
// CHECK: return %[[S3]] : vector<16x16xi1>
554+
500555

501556
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
502557
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
178178
.setFilterConstraint([](Operation *op) {
179179
return success(isa<vector::StepOp>(op));
180180
}));
181+
populateVectorUnrollPatterns(
182+
patterns, UnrollVectorOptions()
183+
.setNativeShape(ArrayRef<int64_t>{8, 8})
184+
.setFilterConstraint([](Operation *op) {
185+
return success(isa<vector::CreateMaskOp>(op));
186+
}));
181187
populateVectorUnrollPatterns(
182188
patterns,
183189
UnrollVectorOptions()

0 commit comments

Comments
 (0)