@@ -4096,6 +4096,75 @@ void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
40964096
40974097namespace {
40984098
4099+ // Pattern to rewrite an ExtractStridedSliceOp(CreateMaskOp) to
4100+ // CreateMaskOp.
4101+ //
4102+ // Example:
4103+ //
4104+ // %mask = vector.create_mask %ub : vector<16xi1>
4105+ // %slice = vector.extract_strided_slice [%offset] [8] [1]
4106+ //
4107+ // to
4108+ //
4109+ // %new_ub = arith.subi %ub, %offset
4110+ // %mask = vector.create_mask %new_ub : vector<8xi1>
4111+ class StridedSliceCreateMaskFolder final
4112+ : public OpRewritePattern<ExtractStridedSliceOp> {
4113+ using OpRewritePattern::OpRewritePattern;
4114+
4115+ public:
4116+ LogicalResult matchAndRewrite (ExtractStridedSliceOp extractStridedSliceOp,
4117+ PatternRewriter &rewriter) const override {
4118+ Location loc = extractStridedSliceOp.getLoc ();
4119+ // Return if 'extractStridedSliceOp' operand is not defined by a
4120+ // CreateMaskOp.
4121+ auto createMaskOp =
4122+ extractStridedSliceOp.getVector ().getDefiningOp <CreateMaskOp>();
4123+ if (!createMaskOp)
4124+ return failure ();
4125+ // Return if 'extractStridedSliceOp' has non-unit strides.
4126+ if (extractStridedSliceOp.hasNonUnitStrides ())
4127+ return failure ();
4128+ // Gather constant mask dimension sizes.
4129+ SmallVector<Value> maskDimSizes (createMaskOp.getOperands ());
4130+ // Gather strided slice offsets and sizes.
4131+ SmallVector<int64_t > sliceOffsets;
4132+ populateFromInt64AttrArray (extractStridedSliceOp.getOffsets (),
4133+ sliceOffsets);
4134+ SmallVector<int64_t > sliceSizes;
4135+ populateFromInt64AttrArray (extractStridedSliceOp.getSizes (), sliceSizes);
4136+
4137+ // Compute slice of vector mask region.
4138+ SmallVector<Value> sliceMaskDimSizes;
4139+ sliceMaskDimSizes.reserve (maskDimSizes.size ());
4140+ // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4141+ // only iterate on the leading dim sizes. The tail accounts for the
4142+ // remaining dim sizes.
4143+ for (auto [maskDimSize, sliceOffset, sliceSize] :
4144+ llvm::zip (maskDimSizes, sliceOffsets, sliceSizes)) {
4145+ // No need to clamp on min/max values, because create_mask has clamping
4146+ // semantics, i.e. the sliceMaskDimSize is allowed to be negative or
4147+ // greater than the vector dim size.
4148+ IntegerAttr offsetAttr =
4149+ rewriter.getIntegerAttr (maskDimSize.getType (), sliceOffset);
4150+ Value offset = rewriter.create <arith::ConstantOp>(loc, offsetAttr);
4151+ Value sliceMaskDimSize =
4152+ rewriter.create <arith::SubIOp>(loc, maskDimSize, offset);
4153+ sliceMaskDimSizes.push_back (sliceMaskDimSize);
4154+ }
4155+ // Add unchanged dimensions.
4156+ llvm::append_range (
4157+ sliceMaskDimSizes,
4158+ llvm::drop_begin (maskDimSizes, sliceMaskDimSizes.size ()));
4159+ // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
4160+ // region.
4161+ rewriter.replaceOpWithNewOp <CreateMaskOp>(
4162+ extractStridedSliceOp, extractStridedSliceOp.getResult ().getType (),
4163+ sliceMaskDimSizes);
4164+ return success ();
4165+ }
4166+ };
4167+
40994168// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
41004169// ConstantMaskOp.
41014170class StridedSliceConstantMaskFolder final
@@ -4117,14 +4186,14 @@ class StridedSliceConstantMaskFolder final
41174186 // Gather constant mask dimension sizes.
41184187 ArrayRef<int64_t > maskDimSizes = constantMaskOp.getMaskDimSizes ();
41194188 // Gather strided slice offsets and sizes.
4120- SmallVector<int64_t , 4 > sliceOffsets;
4189+ SmallVector<int64_t > sliceOffsets;
41214190 populateFromInt64AttrArray (extractStridedSliceOp.getOffsets (),
41224191 sliceOffsets);
4123- SmallVector<int64_t , 4 > sliceSizes;
4192+ SmallVector<int64_t > sliceSizes;
41244193 populateFromInt64AttrArray (extractStridedSliceOp.getSizes (), sliceSizes);
41254194
41264195 // Compute slice of vector mask region.
4127- SmallVector<int64_t , 4 > sliceMaskDimSizes;
4196+ SmallVector<int64_t > sliceMaskDimSizes;
41284197 sliceMaskDimSizes.reserve (maskDimSizes.size ());
41294198 for (auto [maskDimSize, sliceOffset, sliceSize] :
41304199 llvm::zip (maskDimSizes, sliceOffsets, sliceSizes)) {
@@ -4294,9 +4363,9 @@ void ExtractStridedSliceOp::getCanonicalizationPatterns(
42944363 RewritePatternSet &results, MLIRContext *context) {
42954364 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
42964365 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
4297- results.add <StridedSliceConstantMaskFolder, StridedSliceBroadcast ,
4298- StridedSliceSplat, ContiguousExtractStridedSliceToExtract>(
4299- context);
4366+ results.add <StridedSliceCreateMaskFolder, StridedSliceConstantMaskFolder ,
4367+ StridedSliceBroadcast, StridedSliceSplat,
4368+ ContiguousExtractStridedSliceToExtract>( context);
43004369}
43014370
43024371// ===----------------------------------------------------------------------===//
0 commit comments