Skip to content

Commit faee3e5

Browse files
committed
address comments
1 parent 8dc4dda commit faee3e5

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4103,8 +4103,8 @@ class StridedSliceCreateMaskFolder final
41034103
Location loc = extractStridedSliceOp.getLoc();
41044104
// Return if 'extractStridedSliceOp' operand is not defined by a
41054105
// CreateMaskOp.
4106-
auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
4107-
auto createMaskOp = dyn_cast_or_null<CreateMaskOp>(defOp);
4106+
auto createMaskOp =
4107+
extractStridedSliceOp.getVector().getDefiningOp<CreateMaskOp>();
41084108
if (!createMaskOp)
41094109
return failure();
41104110
// Return if 'extractStridedSliceOp' has non-unit strides.
@@ -4122,6 +4122,9 @@ class StridedSliceCreateMaskFolder final
41224122
// Compute slice of vector mask region.
41234123
SmallVector<Value> sliceMaskDimSizes;
41244124
sliceMaskDimSizes.reserve(maskDimSizes.size());
4125+
// sliceOffsets.size() <= maskDimSizes.size(), so we use llvm::zip and
4126+
// only iterate on the leading dim sizes. The tail accounts for the
4127+
// remaining dim sizes.
41254128
for (auto [maskDimSize, sliceOffset, sliceSize] :
41264129
llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
41274130
// No need to clamp on min/max values, because create_mask has clamping
@@ -4135,12 +4138,9 @@ class StridedSliceCreateMaskFolder final
41354138
sliceMaskDimSizes.push_back(sliceMaskDimSize);
41364139
}
41374140
// Add unchanged dimensions.
4138-
if (sliceMaskDimSizes.size() < maskDimSizes.size()) {
4139-
for (size_t i = sliceMaskDimSizes.size(), e = maskDimSizes.size(); i < e;
4140-
++i) {
4141-
sliceMaskDimSizes.push_back(maskDimSizes[i]);
4142-
}
4143-
}
4141+
llvm::append_range(
4142+
sliceMaskDimSizes,
4143+
llvm::drop_begin(maskDimSizes, sliceMaskDimSizes.size()));
41444144
// Replace 'extractStridedSliceOp' with CreateMaskOp with sliced mask
41454145
// region.
41464146
rewriter.replaceOpWithNewOp<CreateMaskOp>(

0 commit comments

Comments
 (0)