diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 5e8421ed67d66..3f5564541554e 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2559,7 +2559,6 @@ def Vector_MaskOp : Vector_Op<"mask", [ Location loc); }]; - let hasCanonicalizer = 1; let hasFolder = 1; let hasCustomAssemblyFormat = 1; let hasVerifier = 1; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bbb366b01fa6e..e404139469b96 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6646,13 +6646,40 @@ LogicalResult MaskOp::verify() { return success(); } -/// Folds vector.mask ops with an all-true mask. +/// Folds empty `vector.mask` with no passthru operand and with or without +/// return values. For example: +/// +/// %0 = vector.mask %mask { vector.yield %a : vector<8xf32> } : +/// vector<8xi1> -> vector<8xf32> +/// %1 = user_op %0 : vector<8xf32> +/// +/// becomes: +/// +/// %0 = user_op %a : vector<8xf32> +/// +static LogicalResult foldEmptyMaskOp(MaskOp maskOp, MaskOp::FoldAdaptor adaptor, + SmallVectorImpl &results) { + if (!maskOp.isEmpty() || maskOp.hasPassthru()) + return failure(); + + Block *block = maskOp.getMaskBlock(); + auto terminator = cast(block->front()); + if (terminator.getNumOperands() == 0) { + // `vector.mask` has no results, just remove the `vector.mask`. + return success(); + } + + // `vector.mask` has results, propagate the results. + llvm::append_range(results, terminator.getOperands()); + return success(); +} + LogicalResult MaskOp::fold(FoldAdaptor adaptor, SmallVectorImpl &results) { - MaskFormat maskFormat = getMaskFormat(getMask()); - if (isEmpty()) - return failure(); + if (succeeded(foldEmptyMaskOp(*this, adaptor, results))) + return success(); + MaskFormat maskFormat = getMaskFormat(getMask()); if (maskFormat != MaskFormat::AllTrue) return failure(); @@ -6665,37 +6692,6 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor, return success(); } -// Elides empty vector.mask operations with or without return values. Propagates -// the yielded values by the vector.yield terminator, if any, or erases the op, -// otherwise. -class ElideEmptyMaskOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(MaskOp maskOp, - PatternRewriter &rewriter) const override { - auto maskingOp = cast(maskOp.getOperation()); - if (maskingOp.getMaskableOp()) - return failure(); - - if (!maskOp.isEmpty()) - return failure(); - - Block *block = maskOp.getMaskBlock(); - auto terminator = cast(block->front()); - if (terminator.getNumOperands() == 0) - rewriter.eraseOp(maskOp); - else - rewriter.replaceOp(maskOp, terminator.getOperands()); - - return success(); - } -}; - -void MaskOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - // MaskingOpInterface definitions. /// Returns the operation masked by this 'vector.mask'. diff --git a/mlir/test/Conversion/GPUCommon/lower-vector.mlir b/mlir/test/Conversion/GPUCommon/lower-vector.mlir index 532a2383cea9e..b4e3da9d0dbfe 100644 --- a/mlir/test/Conversion/GPUCommon/lower-vector.mlir +++ b/mlir/test/Conversion/GPUCommon/lower-vector.mlir @@ -1,10 +1,12 @@ // RUN: mlir-opt %s --gpu-to-llvm | FileCheck %s module { + // CHECK-LABEL: func @func + // CHECK-SAME: %[[IN:.*]]: vector<11xf32> func.func @func(%arg: vector<11xf32>) -> vector<11xf32> { %cst_41 = arith.constant dense : vector<11xi1> - // CHECK: vector.mask - // CHECK-SAME: vector.yield %arg0 + // CHECK-NOT: vector.mask + // CHECK: return %[[IN]] : vector<11xf32> %127 = vector.mask %cst_41 { vector.yield %arg : vector<11xf32> } : vector<11xi1> -> vector<11xf32> return %127 : vector<11xf32> }