@@ -6650,13 +6650,40 @@ LogicalResult MaskOp::verify() {
66506650 return success ();
66516651}
66526652
6653- // / Folds vector.mask ops with an all-true mask.
6653+ // / Folds empty `vector.mask` with no passthru operand and with or without
6654+ // / return values. For example:
6655+ // /
6656+ // / %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } :
6657+ // / vector<8xi1> -> vector<8xf32>
6658+ // / %1 = user_op %0 : vector<8xf32>
6659+ // /
6660+ // / becomes:
6661+ // /
6662+ // / %0 = user_op %a : vector<8xf32>
6663+ // /
6664+ static LogicalResult foldEmptyMaskOp (MaskOp maskOp, MaskOp::FoldAdaptor adaptor,
6665+ SmallVectorImpl<OpFoldResult> &results) {
6666+ if (!maskOp.isEmpty () || maskOp.hasPassthru ())
6667+ return failure ();
6668+
6669+ Block *block = maskOp.getMaskBlock ();
6670+ auto terminator = cast<vector::YieldOp>(block->front ());
6671+ if (terminator.getNumOperands () == 0 ) {
6672+ // `vector.mask` has no results, just remove the `vector.mask`.
6673+ return success ();
6674+ }
6675+
6676+ // `vector.mask` has results, propagate the results.
6677+ llvm::append_range (results, terminator.getOperands ());
6678+ return success ();
6679+ }
6680+
66546681LogicalResult MaskOp::fold (FoldAdaptor adaptor,
66556682 SmallVectorImpl<OpFoldResult> &results) {
6656- MaskFormat maskFormat = getMaskFormat (getMask ());
6657- if (isEmpty ())
6658- return failure ();
6683+ if (succeeded (foldEmptyMaskOp (*this , adaptor, results)))
6684+ return success ();
66596685
6686+ MaskFormat maskFormat = getMaskFormat (getMask ());
66606687 if (maskFormat != MaskFormat::AllTrue)
66616688 return failure ();
66626689
@@ -6669,37 +6696,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
66696696 return success ();
66706697}
66716698
6672- // Elides empty vector.mask operations with or without return values. Propagates
6673- // the yielded values by the vector.yield terminator, if any, or erases the op,
6674- // otherwise.
6675- class ElideEmptyMaskOp : public OpRewritePattern <MaskOp> {
6676- using OpRewritePattern::OpRewritePattern;
6677-
6678- LogicalResult matchAndRewrite (MaskOp maskOp,
6679- PatternRewriter &rewriter) const override {
6680- auto maskingOp = cast<MaskingOpInterface>(maskOp.getOperation ());
6681- if (maskingOp.getMaskableOp ())
6682- return failure ();
6683-
6684- if (!maskOp.isEmpty ())
6685- return failure ();
6686-
6687- Block *block = maskOp.getMaskBlock ();
6688- auto terminator = cast<vector::YieldOp>(block->front ());
6689- if (terminator.getNumOperands () == 0 )
6690- rewriter.eraseOp (maskOp);
6691- else
6692- rewriter.replaceOp (maskOp, terminator.getOperands ());
6693-
6694- return success ();
6695- }
6696- };
6697-
6698- void MaskOp::getCanonicalizationPatterns (RewritePatternSet &results,
6699- MLIRContext *context) {
6700- results.add <ElideEmptyMaskOp>(context);
6701- }
6702-
67036699// MaskingOpInterface definitions.
67046700
67056701// / Returns the operation masked by this 'vector.mask'.
0 commit comments