-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][Vector] Add unroll pattern for vector.create_mask #169119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Nishant Patel (nbpatel) ChangesThis PR adds unrolling for vector.create_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as: Full diff: https://github.com/llvm/llvm-project/pull/169119.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 43ebcaa03a470..d8ed46c2820fe 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2605,7 +2605,9 @@ def Vector_ConstantMaskOp :
}
def Vector_CreateMaskOp :
- Vector_Op<"create_mask", [Pure]>,
+ Vector_Op<"create_mask", [Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface>
+ ]>,
Arguments<(ins Variadic<Index>:$operands)>,
Results<(outs VectorOfAnyRankOf<[I1]>)> {
let summary = "creates a vector mask";
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index b60f80534bfb6..ef239e0b40b04 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,96 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
vector::UnrollVectorOptions options;
};
+/// This pattern unrolls `vector.create_mask` operations into smaller mask
+/// operations based on the target unroll shape. Each unrolled slice computes
+/// its local mask size in each dimension (d) as:
+/// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]).
+/// Example:
+/// Given a create_mask operation:
+/// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10
+/// elements
+///
+/// and a target unroll shape of <4x8>, the pattern produces:
+///
+/// %false = arith.constant dense<false> : vector<8x16xi1>
+///
+/// Slice [0,0]:
+/// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8
+/// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1>
+/// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [0,8]:
+/// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2
+/// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1>
+/// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [4,0]:
+/// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8
+/// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1>
+/// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+/// Slice [4,8]:
+/// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2
+/// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1>
+/// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1]
+/// : vector<4x8xi1> into vector<8x16xi1>
+struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> {
+ UnrollCreateMaskPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::CreateMaskOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp,
+ PatternRewriter &rewriter) const override {
+ auto targetShape = getTargetShape(options, createMaskOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType resultType = createMaskOp.getVectorType();
+ SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll();
+ Location loc = createMaskOp.getLoc();
+
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+ auto targetVectorType = VectorType::get(*targetShape, rewriter.getI1Type());
+ SmallVector<int64_t> strides(targetShape->size(), 1);
+
+ // In each dimension (d), each unrolled vector computes its mask size as:
+ // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]).
+ for (SmallVector<int64_t> offsets :
+ StaticTileOffsetRange(originalSize, *targetShape)) {
+ SmallVector<Value> unrolledOperands;
+
+ for (auto [i, originalMaskOperand] :
+ llvm::enumerate(createMaskOp.getOperands())) {
+ Value offsetVal =
+ arith::ConstantIndexOp::create(rewriter, loc, offsets[i]);
+ Value adjustedMaskSize = arith::SubIOp::create(
+ rewriter, loc, originalMaskOperand, offsetVal);
+ Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ Value unrolledDimSize =
+ arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]);
+ Value nonNegative =
+ arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero);
+ Value unrolledOperand =
+ arith::MinSIOp::create(rewriter, loc, nonNegative, unrolledDimSize);
+ unrolledOperands.push_back(unrolledOperand);
+ }
+
+ auto unrolledMask = vector::CreateMaskOp::create(
+ rewriter, loc, targetVectorType, unrolledOperands);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, unrolledMask, result, offsets, strides);
+ }
+ rewriter.replaceOp(createMaskOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
/// Checks whether extractShape is a contiguous slice of shape.
/// For extractShape to be contiguous in shape:
/// 1) All but the leading dimension of extractShape and shape must match
@@ -1202,8 +1292,9 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
- UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
- patterns.getContext(), options, benefit);
+ UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern,
+ UnrollCreateMaskPattern>(patterns.getContext(), options,
+ benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index dec32e1c61a9b..70a34f227802e 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -497,6 +497,45 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-NOT: arith.addf
// CHECK: return
+func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> {
+ %0 = vector.create_mask %size1, %size2 : vector<16x16xi1>
+ return %0 : vector<16x16xi1>
+}
+
+// CHECK-LABEL: func @vector_create_mask
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1>
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C8:.*]] = arith.constant 8 : index
+// CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+// CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index
+// CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+// CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index
+// CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1>
+// CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index
+// CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index
+// CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+// CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index
+// CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index
+// CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1>
+// CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+// CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index
+// CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index
+// CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index
+// CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index
+// CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1>
+// CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index
+// CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index
+// CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index
+// CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index
+// CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index
+// CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index
+// CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1>
+// CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1>
+// CHECK: return %[[INS11]] : vector<16x16xi1>
func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
%0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index e8ea0cc02d7f6..91e8cf72e64c3 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -178,6 +178,12 @@ struct TestVectorUnrollingPatterns
.setFilterConstraint([](Operation *op) {
return success(isa<vector::StepOp>(op));
}));
+ populateVectorUnrollPatterns(patterns,
+ UnrollVectorOptions()
+ .setNativeShape(ArrayRef<int64_t>{8, 8})
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::CreateMaskOp>(op));
+ }));
populateVectorUnrollPatterns(
patterns,
UnrollVectorOptions()
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
🐧 Linux x64 Test Results
|
amd-eochoalo
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just that small nit. Thanks!
|
pinging for review @banach-space @kuhar |
| auto unrolledMask = vector::CreateMaskOp::create( | ||
| rewriter, loc, targetVectorType, unrolledOperands); | ||
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( | ||
| loc, unrolledMask, result, offsets, strides); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We'd need to use createOrFold for many if not all of the previous operations. For the example that you provided in the documentation, I'd expect everything to fold into a single constant mask op. We should make sure that such a folding happens in the tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing that out..changed it
| // CHECK: return | ||
|
|
||
| func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> { | ||
| %0 = vector.create_mask %size1, %size2 : vector<16x16xi1> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please, add a test with (full or partial?) constant indices to make sure that folding happens
|
@dcaballe any further comments? |
dcaballe
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG, thanks!
This PR adds unrolling for vector.create_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as: min(max(originalMaskSize[d] - offset[d], 0), unrolledMaskSize[d]).
This PR adds unrolling for vector.create_mask op based on the targetShape. Each unrolled vector computes its local mask size in each dimension (d) as:
min(max(originalMaskSize[d] - offset[d], 0), unrolledMaskSize[d]).