@@ -4103,8 +4103,8 @@ class StridedSliceCreateMaskFolder final
41034103 Location loc = extractStridedSliceOp.getLoc ();
41044104 // Return if 'extractStridedSliceOp' operand is not defined by a
41054105 // CreateMaskOp.
4106- auto *defOp = extractStridedSliceOp. getVector (). getDefiningOp ();
4107- auto createMaskOp = dyn_cast_or_null <CreateMaskOp>(defOp );
4106+ auto createMaskOp =
4107+ extractStridedSliceOp. getVector (). getDefiningOp <CreateMaskOp>();
41084108 if (!createMaskOp)
41094109 return failure ();
41104110 // Return if 'extractStridedSliceOp' has non-unit strides.
@@ -4122,6 +4122,9 @@ class StridedSliceCreateMaskFolder final
41224122 // Compute slice of vector mask region.
41234123 SmallVector<Value> sliceMaskDimSizes;
41244124 sliceMaskDimSizes.reserve (maskDimSizes.size ());
4125+ // sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4126+ // only iterate on the leading dim sizes. The tail accounts for the
4127+ // remaining dim sizes.
41254128 for (auto [maskDimSize, sliceOffset, sliceSize] :
41264129 llvm::zip (maskDimSizes, sliceOffsets, sliceSizes)) {
41274130 // No need to clamp on min/max values, because create_mask has clamping
@@ -4135,12 +4138,9 @@ class StridedSliceCreateMaskFolder final
41354138 sliceMaskDimSizes.push_back (sliceMaskDimSize);
41364139 }
41374140 // Add unchanged dimensions.
4138- if (sliceMaskDimSizes.size () < maskDimSizes.size ()) {
4139- for (size_t i = sliceMaskDimSizes.size (), e = maskDimSizes.size (); i < e;
4140- ++i) {
4141- sliceMaskDimSizes.push_back (maskDimSizes[i]);
4142- }
4143- }
4141+ llvm::append_range (
4142+ sliceMaskDimSizes,
4143+ llvm::drop_begin (maskDimSizes, sliceMaskDimSizes.size ()));
41444144 // Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
41454145 // region.
41464146 rewriter.replaceOpWithNewOp <CreateMaskOp>(
0 commit comments