-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[MLIR][Vector] Add unroll pattern for vector.create_mask #169119
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1003,6 +1003,96 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> { | |
| vector::UnrollVectorOptions options; | ||
| }; | ||
|
|
||
| /// This pattern unrolls `vector.create_mask` operations into smaller mask | ||
| /// operations based on the target unroll shape. Each unrolled slice computes | ||
| /// its local mask size in each dimension (d) as: | ||
| /// min(max(originalMaskSize[d] - offset[d], 0), unrolledDimSize[d]). | ||
| /// Example: | ||
| /// Given a create_mask operation: | ||
| /// %0 = vector.create_mask %c6, %c10 : vector<8x16xi1> // mask first 6x10 | ||
| /// elements | ||
| /// | ||
| /// and a target unroll shape of <4x8>, the pattern produces: | ||
| /// | ||
| /// %false = arith.constant dense<false> : vector<8x16xi1> | ||
| /// | ||
| /// Slice [0,0]: | ||
| /// mask size = min(max(6-0, 0), 4) x min(max(10-0, 0), 8) = 4x8 | ||
| /// %mask00 = vector.create_mask %c4, %c8 : vector<4x8xi1> | ||
| /// %r0 = vector.insert_strided_slice %mask00, %false [0, 0], [1, 1] | ||
| /// : vector<4x8xi1> into vector<8x16xi1> | ||
| /// Slice [0,8]: | ||
| /// mask size = min(max(6-0, 0), 4) x min(max(10-8, 0), 8) = 4x2 | ||
| /// %mask01 = vector.create_mask %c4, %c2 : vector<4x8xi1> | ||
| /// %r1 = vector.insert_strided_slice %mask01, %r0 [0, 8], [1, 1] | ||
| /// : vector<4x8xi1> into vector<8x16xi1> | ||
| /// Slice [4,0]: | ||
| /// mask size = min(max(6-4, 0), 4) x min(max(10-0, 0), 8) = 2x8 | ||
| /// %mask10 = vector.create_mask %c2, %c8 : vector<4x8xi1> | ||
| /// %r2 = vector.insert_strided_slice %mask10, %r1 [4, 0], [1, 1] | ||
| /// : vector<4x8xi1> into vector<8x16xi1> | ||
| /// Slice [4,8]: | ||
| /// mask size = min(max(6-4, 0), 4) x min(max(10-8, 0), 8) = 2x2 | ||
| /// %mask11 = vector.create_mask %c2, %c2 : vector<4x8xi1> | ||
| /// %result = vector.insert_strided_slice %mask11, %r2 [4, 8], [1, 1] | ||
| /// : vector<4x8xi1> into vector<8x16xi1> | ||
| struct UnrollCreateMaskPattern : public OpRewritePattern<vector::CreateMaskOp> { | ||
| UnrollCreateMaskPattern(MLIRContext *context, | ||
| const vector::UnrollVectorOptions &options, | ||
| PatternBenefit benefit = 1) | ||
| : OpRewritePattern<vector::CreateMaskOp>(context, benefit), | ||
| options(options) {} | ||
|
|
||
| LogicalResult matchAndRewrite(vector::CreateMaskOp createMaskOp, | ||
| PatternRewriter &rewriter) const override { | ||
| auto targetShape = getTargetShape(options, createMaskOp); | ||
| if (!targetShape) | ||
| return failure(); | ||
|
|
||
| VectorType resultType = createMaskOp.getVectorType(); | ||
| SmallVector<int64_t> originalSize = *createMaskOp.getShapeForUnroll(); | ||
| Location loc = createMaskOp.getLoc(); | ||
|
|
||
| Value result = arith::ConstantOp::create(rewriter, loc, resultType, | ||
| rewriter.getZeroAttr(resultType)); | ||
| auto targetVectorType = VectorType::get(*targetShape, rewriter.getI1Type()); | ||
| SmallVector<int64_t> strides(targetShape->size(), 1); | ||
|
|
||
| // In each dimension (d), each unrolled vector computes its mask size as: | ||
| // min(max(originalMaskOperands[d] - offset[d], 0), unrolledDimSize[d]). | ||
| for (SmallVector<int64_t> offsets : | ||
| StaticTileOffsetRange(originalSize, *targetShape)) { | ||
| SmallVector<Value> unrolledOperands; | ||
|
|
||
| for (auto [i, originalMaskOperand] : | ||
| llvm::enumerate(createMaskOp.getOperands())) { | ||
| Value offsetVal = | ||
| arith::ConstantIndexOp::create(rewriter, loc, offsets[i]); | ||
| Value adjustedMaskSize = arith::SubIOp::create( | ||
| rewriter, loc, originalMaskOperand, offsetVal); | ||
| Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); | ||
| Value unrolledDimSize = | ||
| arith::ConstantIndexOp::create(rewriter, loc, (*targetShape)[i]); | ||
| Value nonNegative = | ||
| arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); | ||
| Value unrolledOperand = | ||
| arith::MinSIOp::create(rewriter, loc, nonNegative, unrolledDimSize); | ||
| unrolledOperands.push_back(unrolledOperand); | ||
| } | ||
|
|
||
| auto unrolledMask = vector::CreateMaskOp::create( | ||
| rewriter, loc, targetVectorType, unrolledOperands); | ||
| result = rewriter.createOrFold<vector::InsertStridedSliceOp>( | ||
| loc, unrolledMask, result, offsets, strides); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'd need to use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for pointing that out..changed it |
||
| } | ||
| rewriter.replaceOp(createMaskOp, result); | ||
| return success(); | ||
| } | ||
|
|
||
| private: | ||
| vector::UnrollVectorOptions options; | ||
| }; | ||
|
|
||
| /// Checks whether extractShape is a contiguous slice of shape. | ||
| /// For extractShape to be contiguous in shape: | ||
| /// 1) All but the leading dimension of extractShape and shape must match | ||
|
|
@@ -1202,8 +1292,9 @@ void mlir::vector::populateVectorUnrollPatterns( | |
| UnrollReductionPattern, UnrollMultiReductionPattern, | ||
| UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern, | ||
| UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements, | ||
| UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>( | ||
| patterns.getContext(), options, benefit); | ||
| UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern, | ||
| UnrollCreateMaskPattern>(patterns.getContext(), options, | ||
| benefit); | ||
| } | ||
|
|
||
| void mlir::vector::populateVectorToElementsUnrollPatterns( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -497,6 +497,45 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3 | |
| // CHECK-NOT: arith.addf | ||
| // CHECK: return | ||
|
|
||
| func.func @vector_create_mask(%size1: index, %size2: index) -> vector<16x16xi1> { | ||
| %0 = vector.create_mask %size1, %size2 : vector<16x16xi1> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please, add a test with (full or partial?) constant indices to make sure that folding happens |
||
| return %0 : vector<16x16xi1> | ||
| } | ||
|
|
||
| // CHECK-LABEL: func @vector_create_mask | ||
| // CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<16x16xi1> | ||
| // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16x16xi1> | ||
| // CHECK: %[[C0:.*]] = arith.constant 0 : index | ||
| // CHECK: %[[C8:.*]] = arith.constant 8 : index | ||
| // CHECK: %[[MAX0:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index | ||
| // CHECK: %[[MIN0:.*]] = arith.minsi %[[MAX0]], %[[C8]] : index | ||
| // CHECK: %[[MAX1:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index | ||
| // CHECK: %[[MIN1:.*]] = arith.minsi %[[MAX1]], %[[C8]] : index | ||
| // CHECK: %[[MASK00:.*]] = vector.create_mask %[[MIN0]], %[[MIN1]] : vector<8x8xi1> | ||
| // CHECK: %[[INS00:.*]] = vector.insert_strided_slice %[[MASK00]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> | ||
| // CHECK: %[[MAX0_2:.*]] = arith.maxsi %[[ARG0]], %[[C0]] : index | ||
| // CHECK: %[[MIN0_2:.*]] = arith.minsi %[[MAX0_2]], %[[C8]] : index | ||
| // CHECK: %[[SUB1:.*]] = arith.subi %[[ARG1]], %[[C8]] : index | ||
| // CHECK: %[[MAX1_2:.*]] = arith.maxsi %[[SUB1]], %[[C0]] : index | ||
| // CHECK: %[[MIN1_2:.*]] = arith.minsi %[[MAX1_2]], %[[C8]] : index | ||
| // CHECK: %[[MASK01:.*]] = vector.create_mask %[[MIN0_2]], %[[MIN1_2]] : vector<8x8xi1> | ||
| // CHECK: %[[INS01:.*]] = vector.insert_strided_slice %[[MASK01]], %[[INS00]] {offsets = [0, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> | ||
| // CHECK: %[[SUB0:.*]] = arith.subi %[[ARG0]], %[[C8]] : index | ||
| // CHECK: %[[MAX0_3:.*]] = arith.maxsi %[[SUB0]], %[[C0]] : index | ||
| // CHECK: %[[MIN0_3:.*]] = arith.minsi %[[MAX0_3]], %[[C8]] : index | ||
| // CHECK: %[[MAX1_3:.*]] = arith.maxsi %[[ARG1]], %[[C0]] : index | ||
| // CHECK: %[[MIN1_3:.*]] = arith.minsi %[[MAX1_3]], %[[C8]] : index | ||
| // CHECK: %[[MASK10:.*]] = vector.create_mask %[[MIN0_3]], %[[MIN1_3]] : vector<8x8xi1> | ||
| // CHECK: %[[INS10:.*]] = vector.insert_strided_slice %[[MASK10]], %[[INS01]] {offsets = [8, 0], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> | ||
| // CHECK: %[[SUB0_2:.*]] = arith.subi %[[ARG0]], %[[C8]] : index | ||
| // CHECK: %[[MAX0_4:.*]] = arith.maxsi %[[SUB0_2]], %[[C0]] : index | ||
| // CHECK: %[[MIN0_4:.*]] = arith.minsi %[[MAX0_4]], %[[C8]] : index | ||
| // CHECK: %[[SUB1_2:.*]] = arith.subi %[[ARG1]], %[[C8]] : index | ||
| // CHECK: %[[MAX1_4:.*]] = arith.maxsi %[[SUB1_2]], %[[C0]] : index | ||
| // CHECK: %[[MIN1_4:.*]] = arith.minsi %[[MAX1_4]], %[[C8]] : index | ||
| // CHECK: %[[MASK11:.*]] = vector.create_mask %[[MIN0_4]], %[[MIN1_4]] : vector<8x8xi1> | ||
| // CHECK: %[[INS11:.*]] = vector.insert_strided_slice %[[MASK11]], %[[INS10]] {offsets = [8, 8], strides = [1, 1]} : vector<8x8xi1> into vector<16x16xi1> | ||
| // CHECK: return %[[INS11]] : vector<16x16xi1> | ||
|
|
||
| func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> { | ||
| %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32> | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.