-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][NFC] Simplify constant checks with isOneInteger and renamed isZeroInteger. #139340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-scf Author: Han-Chung Wang (hanhanW) ChangesThe revision adds isOneIndex helper, and simplifies the existing code with the two methods. It removes some lambda, which makes code cleaner. Full diff: https://github.com/llvm/llvm-project/pull/139340.diff 18 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 2a3a2defb810d..ea1a2384f8cba 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -28,6 +28,10 @@ namespace mlir {
/// with attribute with value `0`.
bool isZeroIndex(OpFoldResult v);
+/// Return true if `v` is an IntegerAttr with value `1` of a ConstantIndexOp
+/// with attribute with value `1`.
+bool isOneIndex(OpFoldResult v);
+
/// Represents a range (offset, size, and stride) where each element of the
/// triple may be dynamic or static.
struct Range {
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 04bc62262c3d8..c9e7ae6f8bdb5 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
OpFoldResult offset =
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
.front();
- if (isConstantIntValue(offset, 0)) {
+ if (isZeroIndex(offset)) {
rewriter.replaceOp(op, src);
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fce0751430305..a6b1e21cd3b53 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4426,8 +4426,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// Return true if we have a zero-value tile.
auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
- return llvm::any_of(
- tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
+ return llvm::any_of(tiles, isZeroIndex);
};
// Verify tiles. Do not allow zero tiles.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f6ca109b84f9e..25b0635220f3b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3315,10 +3315,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
SmallVector<OpFoldResult> steps = loop.getMixedStep();
- if (llvm::all_of(
- lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
- llvm::all_of(
- steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+ if (llvm::all_of(lbs, isZeroIndex) && llvm::all_of(steps, isOneIndex)) {
return loop;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7c2788f16a3b6..700be3ad35705 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -23,6 +23,7 @@
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinOps.h"
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
SmallVector<Value> threadIds = forallOp.getInductionVars();
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
- numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+ numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
int64_t nLoops = loopRanges.size();
tiledOffsets.reserve(nLoops);
tiledSizes.reserve(nLoops);
for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
bool overflow = loopIdx >= numThreads.size();
- bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+ bool isZero = !overflow && isZeroIndex(numThreads[loopIdx]);
// Degenerate case: take the whole domain.
if (overflow || isZero) {
tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
b, loc, i + j * m - n,
{offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
- if (!isConstantIntValue(residualTileSize, 0)) {
+ if (!isZeroIndex(residualTileSize)) {
OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
b, loc, -i + m, {offsetPerThread, size});
tileSizePerThread =
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
Operation *tiledOp = nullptr;
SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
- numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+ numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
SmallVector<Value> materializedNonZeroNumThreads =
getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 0cc840403a020..faae77a6eecb3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -732,7 +732,7 @@ struct PackOpTiling
// iterated or inner dims are not tiled. Otherwise, it will generate a
// sequence of non-trivial ops (for partial tiles).
for (auto offset : offsets.take_back(numTiles))
- if (!isConstantIntValue(offset, 0))
+ if (!isZeroIndex(offset))
return failure();
for (auto iter :
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..1175c57694272 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1889,9 +1889,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
// reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
// are 0.
if (auto prev = src.getDefiningOp<SubViewOp>())
- if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
- return isConstantIntValue(val, 0);
- }))
+ if (llvm::all_of(prev.getMixedOffsets(), isZeroIndex))
return prev.getSource();
return nullptr;
@@ -3285,11 +3283,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
auto srcSizes = srcSubview.getMixedSizes();
auto sizes = getMixedSizes();
auto offsets = getMixedOffsets();
- bool allOffsetsZero = llvm::all_of(
- offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+ bool allOffsetsZero = llvm::all_of(offsets, isZeroIndex);
auto strides = getMixedStrides();
- bool allStridesOne = llvm::all_of(
- strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+ bool allStridesOne = llvm::all_of(strides, isOneIndex);
bool allSizesSame = llvm::equal(sizes, srcSizes);
if (allOffsetsZero && allStridesOne && allSizesSame &&
resultMemrefType == sourceMemrefType)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 05ba6a3f38708..e28f7d3e4924a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -251,10 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
// to do.
SmallVector<OpFoldResult> indices =
getAsOpFoldResult(loadStoreLikeOp.getIndices());
- if (std::all_of(indices.begin(), indices.end(),
- [](const OpFoldResult &opFold) {
- return isConstantIntValue(opFold, 0);
- })) {
+ if (std::all_of(indices.begin(), indices.end(), isZeroIndex)) {
return rewriter.notifyMatchFailure(
loadStoreLikeOp, "no computation to extract: offsets are 0s");
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0cd7da5db9163..d7d42219bc7b6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
tileSizes.resize(numLoops, zero);
for (auto [index, range, nt] :
llvm::enumerate(iterationDomain, numThreads)) {
- if (isConstantIntValue(nt, 0))
+ if (isZeroIndex(nt))
continue;
tileSizes[index] = affine::makeComposedFoldedAffineApply(
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
// Non-tiled cases, set the offset and size to the
// `loopRange.offset/size`.
- if (isConstantIntValue(nt, 0)) {
+ if (isZeroIndex(nt)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
continue;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
{loopRange.offset, nt, tileSize, loopRange.size});
OpFoldResult size = tileSize;
- if (!isConstantIntValue(residualTileSize, 0)) {
+ if (!isZeroIndex(residualTileSize)) {
OpFoldResult sizeMinusOffsetPerThread =
affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
{offset, loopRange.size});
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
// Non-tiled cases, set the offset and size to the
// `loopRange.offset/size`.
- if (isConstantIntValue(tileSize, 0)) {
+ if (isZeroIndex(tileSize)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
continue;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
SmallVector<OpFoldResult> lbs, ubs, steps;
for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
// No loop if the tile size is 0.
- if (isConstantIntValue(tileSize, 0))
+ if (isZeroIndex(tileSize))
continue;
lbs.push_back(loopRange.offset);
ubs.push_back(loopRange.size);
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
// Prune the zero numthreads.
SmallVector<OpFoldResult> nonZeroNumThreads;
for (auto nt : numThreads) {
- if (isConstantIntValue(nt, 0))
+ if (isZeroIndex(nt))
continue;
nonZeroNumThreads.push_back(nt);
}
@@ -1290,9 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
sliceSizes = sliceOp.getMixedSizes();
// expect all strides of sliceOp being 1
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
return failure();
unsigned sliceResultNumber =
@@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
// 9. Check all insert stride is 1.
- if (llvm::any_of(strides, [](OpFoldResult stride) {
- return !isConstantIntValue(stride, 1);
- })) {
+ if (!llvm::all_of(strides, isOneIndex)) {
return rewriter.notifyMatchFailure(
candidateSliceOp, "containingOp's result yield with stride");
}
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d9550fe18dc02..f95e38fc75c8d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
// If an `affine.apply` operation is generated for denormalization, the use
// of `origLb` in those ops must not be replaced. These arent not generated
// when `origLb == 0` and `origStep == 1`.
- if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
+ if (!isZeroIndex(origLb) || !isOneIndex(origStep)) {
if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
preservedUses.insert(preservedUse);
}
@@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
}
Value denormalizedIv;
SmallPtrSet<Operation *, 2> preserve;
- bool isStepOne = isConstantIntValue(origStep, 1);
- bool isZeroBased = isConstantIntValue(origLb, 0);
+ bool isStepOne = isOneIndex(origStep);
+ bool isZeroBased = isZeroIndex(origLb);
Value scaled = normalizedIv;
if (!isStepOne) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index b2eca539194a8..649375b4c4037 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
// Check for single block, unit-stride for-loop that is generated by
// sparsifier, which means no data dependence analysis is required,
// and its loop-body is very restricted in form.
- if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
+ if (!op.getRegion().hasOneBlock() || !isOneIndex(op.getStep()) ||
!op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
return failure();
// Analyze (!codegen) and rewrite (codegen) loop-body.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 29da32cd1791c..717ea1d0d7618 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2738,8 +2738,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
return getResult();
if (auto result = foldInsertAfterExtractSlice(*this))
return result;
- if (llvm::any_of(getMixedSizes(),
- [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+ if (llvm::any_of(getMixedSizes(), isZeroIndex))
return getDest();
return OpFoldResult();
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 7778a02dbeaf4..41407064cb6d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -135,9 +135,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
for (unsigned dim = 0; dim < rank; ++dim) {
auto low = padOp.getMixedLowPad()[dim];
- bool hasLowPad = !isConstantIntValue(low, 0);
+ bool hasLowPad = !isZeroIndex(low);
auto high = padOp.getMixedHighPad()[dim];
- bool hasHighPad = !isConstantIntValue(high, 0);
+ bool hasHighPad = !isZeroIndex(high);
auto offset = offsets[dim];
auto length = sizes[dim];
// If the dim has no padding, we dont need to calculate new values for that
@@ -208,7 +208,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
// Check if newLength is zero. In that case, no SubTensorOp should be
// executed.
- if (isConstantIntValue(newLength, 0)) {
+ if (isZeroIndex(newLength)) {
hasZeroLen = true;
} else if (!hasZeroLen) {
Value check = b.create<arith::CmpIOp>(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index a3de7f9b44ae6..9978aac1ee80e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
isZeroOffsetAndFullSize =
[](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
- if (!isConstantIntValue(offset, 0))
+ if (!isZeroIndex(offset))
return false;
FailureOr<bool> maybeEqual =
ValueBoundsConstraintSet::areEqual(sliceSize, size);
@@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
// Find the first expanded dim after the first dim with non-unit extracted
// size.
for (; i < e; ++i) {
- if (!isConstantIntValue(sizes[indices[i]], 1)) {
+ if (!isOneIndex(sizes[indices[i]])) {
// +1 to skip the first non-unit size dim.
i++;
break;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 858adfc436164..36cc31e614f21 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -27,9 +27,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
return failure();
// `TilingInterface` currently only supports strides being 1.
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
return failure();
FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
@@ -49,9 +47,7 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
return failure();
// `TilingInterface` currently only supports strides being 1.
- if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 1);
- }))
+ if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
return failure();
FailureOr<TilingResult> tiledResult =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index fcb736aa031f3..51b51d8aa32e4 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -18,10 +18,13 @@ namespace mlir {
bool isZeroIndex(OpFoldResult v) {
if (!v)
return false;
- std::optional<int64_t> constint = getConstantIntValue(v);
- if (!constint)
+ return isConstantIntValue(v, 0);
+}
+
+bool isOneIndex(OpFoldResult v) {
+ if (!v)
return false;
- return *constint == 0;
+ return isConstantIntValue(v, 1);
}
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..4e5c60671b976 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -141,7 +141,7 @@ struct LinearizeVectorExtractStridedSlice final
ArrayAttr offsets = extractOp.getOffsets();
ArrayAttr sizes = extractOp.getSizes();
ArrayAttr strides = extractOp.getStrides();
- if (!isConstantIntValue(strides[0], 1))
+ if (!isOneIndex(strides[0]))
return rewriter.notifyMatchFailure(
extractOp, "Strided slice with stride != 1 is not supported.");
Value srcVector = adaptor.getVector();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b94c5fce64f83..83dc34e4b4139 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
ArithIndexingBuilder idxBuilderf(rewriter, loc);
for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
OpFoldResult pos = extractPos[i - rankOffset];
- if (isConstantIntValue(pos, 0))
+ if (isZeroIndex(pos))
continue;
Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);
|
el-ev
reviewed
May 10, 2025
mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
Outdated
Show resolved
Hide resolved
The revision adds isOneIndex helper, and simplifies the existing code with the two methods. It removes some lambda, which makes code cleaner. Signed-off-by: hanhanW <[email protected]>
Signed-off-by: hanhanW <[email protected]>
Signed-off-by: hanhanW <[email protected]>
cdcb0af to
1e4b65e
Compare
matthias-springer
approved these changes
May 17, 2025
Signed-off-by: hanhanW <[email protected]>
rolfmorel
added a commit
to libxsmm/tpp-mlir
that referenced
this pull request
Jun 12, 2025
* llvm/llvm-project#139340 ``` sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` * llvm/llvm-project#141466 & llvm/llvm-project#141019 * Add `BufferizationState &state` to `bufferize` and `getBuffer` * llvm/llvm-project#143159 & llvm/llvm-project#142683 & llvm/llvm-project#143779 * Updates to `transform.apply_registered_pass` and its Python-bindings * llvm/llvm-project#143217 * `tilingResult->mergeResult.replacements` -> `tilingResult->replacements` * llvm/llvm-project#140559 & llvm/llvm-project#143871 * Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s & fix which enables conversion again.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
mlir:linalg
mlir:memref
mlir:scf
mlir:sparse
Sparse compiler in MLIR
mlir:spirv
mlir:tensor
mlir:vector
mlir:vectorops
mlir
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The revision adds isOneInteger helper, and simplifies the existing code with the two methods. It removes some lambda, which makes code cleaner.
For downstream users, you can update the code with the below script.