-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][Vector] Move vector.mask canonicalization to folder
#140324
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6646,13 +6646,42 @@ LogicalResult MaskOp::verify() { | |
| return success(); | ||
| } | ||
|
|
||
| /// Folds vector.mask ops with an all-true mask. | ||
| /// Folds empty `vector.mask` with no passthru operand and with or without | ||
| /// return values. For example: | ||
| /// | ||
| /// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : | ||
| /// vector<8xi1> -> vector<8xf32> | ||
| /// %1 = user_op %0 : vector<8xf32> | ||
| /// | ||
| /// becomes: | ||
| /// | ||
| /// %0 = user_op %a : vector<8xf32> | ||
| /// | ||
| /// `vector.mask` with a passthru is handled by the canonicalizer. | ||
|
||
| /// | ||
| static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor, | ||
| SmallVectorImpl<OpFoldResult> &results) { | ||
| if (!maskOp.isEmpty() || maskOp.hasPassthru()) | ||
| return failure(); | ||
|
|
||
| Block *block = maskOp.getMaskBlock(); | ||
| auto terminator = cast<vector::YieldOp>(block->front()); | ||
| if (terminator.getNumOperands() == 0) { | ||
| // `vector.mask` has no results, just remove the `vector.mask`. | ||
| return success(); | ||
| } | ||
|
|
||
| // `vector.mask` has results, propagate the results. | ||
| llvm::append_range(results, terminator.getOperands()); | ||
| return success(); | ||
| } | ||
|
|
||
| LogicalResult MaskOp::fold(FoldAdaptor adaptor, | ||
| SmallVectorImpl<OpFoldResult> &results) { | ||
| MaskFormat maskFormat = getMaskFormat(getMask()); | ||
| if (isEmpty()) | ||
| return failure(); | ||
| if (succeeded(foldEmptyMaskOp(*this, adaptor, results))) | ||
| return success(); | ||
|
|
||
| MaskFormat maskFormat = getMaskFormat(getMask()); | ||
| if (maskFormat != MaskFormat::AllTrue) | ||
| return failure(); | ||
|
|
||
|
|
@@ -6665,37 +6694,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor, | |
| return success(); | ||
| } | ||
|
|
||
| // Elides empty vector.mask operations with or without return values. Propagates | ||
| // the yielded values by the vector.yield terminator, if any, or erases the op, | ||
| // otherwise. | ||
| class ElideEmptyMaskOp : public OpRewritePattern<MaskOp> { | ||
| using OpRewritePattern::OpRewritePattern; | ||
|
|
||
| LogicalResult matchAndRewrite(MaskOp maskOp, | ||
| PatternRewriter &rewriter) const override { | ||
| auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation()); | ||
| if (maskingOp.getMaskableOp()) | ||
| return failure(); | ||
|
|
||
| if (!maskOp.isEmpty()) | ||
| return failure(); | ||
|
|
||
| Block *block = maskOp.getMaskBlock(); | ||
| auto terminator = cast<vector::YieldOp>(block->front()); | ||
| if (terminator.getNumOperands() == 0) | ||
| rewriter.eraseOp(maskOp); | ||
| else | ||
| rewriter.replaceOp(maskOp, terminator.getOperands()); | ||
|
|
||
| return success(); | ||
| } | ||
| }; | ||
|
|
||
| void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results, | ||
| MLIRContext *context) { | ||
| results.add<ElideEmptyMaskOp>(context); | ||
| } | ||
|
|
||
| // MaskingOpInterface definitions. | ||
|
|
||
| /// Returns the operation masked by this 'vector.mask'. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,10 +1,12 @@ | ||
| // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s | ||
|
|
||
| module { | ||
| // CHECK-LABEL: func @func | ||
| // CHECK-SAME: %[[IN:.*]]: vector<11xf32> | ||
| func.func @func(%arg: vector<11xf32>) -> vector<11xf32> { | ||
| %cst_41 = arith.constant dense<true> : vector<11xi1> | ||
| // CHECK: vector.mask | ||
| // CHECK-SAME: vector.yield %arg0 | ||
| // CHECK-NOT: vector.mask | ||
| // CHECK: return %[[IN]] : vector<11xf32> | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure what the intent of this test is but I'm updating it accordingly
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a regression test originally added in 6904998 . It’s a great example of why "regression" on its own isn’t a helpful test description 😅 That commit fixed a bug in From the context of the original commit, it seems the intent was to test the lowering of empty vector.mask. If so, I think this would be better placed in: mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir WDYT? Would you mind moving it there instead? Also - if I’ve understood the original intent correctly, your update here makes sense 🙂 |
||
| %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32> | ||
| return %127 : vector<11xf32> | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this line anymore? I thibk we can remove it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I'm adding that back in my next PR but I removed it from now.