Skip to content

Conversation

@hanhanW
Copy link
Contributor

@hanhanW hanhanW commented May 10, 2025

The revision adds isOneInteger helper, and simplifies the existing code with the two methods. It removes some lambda, which makes code cleaner.

For downstream users, you can update the code with the below script.

sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp

@llvmbot
Copy link
Member

llvmbot commented May 10, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-scf

Author: Han-Chung Wang (hanhanW)

Changes

The revision adds isOneIndex helper, and simplifies the existing code with the two methods. It removes some lambda, which makes code cleaner.


Full diff: https://github.com/llvm/llvm-project/pull/139340.diff

18 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Utils/StaticValueUtils.h (+4)
  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+5-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+1-1)
  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+3-7)
  • (modified) mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp (+1-4)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+8-12)
  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+1-2)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp (+2-6)
  • (modified) mlir/lib/Dialect/Utils/StaticValueUtils.cpp (+6-3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
index 2a3a2defb810d..ea1a2384f8cba 100644
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -28,6 +28,10 @@ namespace mlir {
 /// with attribute with value `0`.
 bool isZeroIndex(OpFoldResult v);
 
+/// Return true if `v` is an IntegerAttr with value `1` of a ConstantIndexOp
+/// with attribute with value `1`.
+bool isOneIndex(OpFoldResult v);
+
 /// Represents a range (offset, size, and stride) where each element of the
 /// triple may be dynamic or static.
 struct Range {
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 04bc62262c3d8..c9e7ae6f8bdb5 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
   OpFoldResult offset =
       getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter)
           .front();
-  if (isConstantIntValue(offset, 0)) {
+  if (isZeroIndex(offset)) {
     rewriter.replaceOp(op, src);
     return success();
   }
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index fce0751430305..a6b1e21cd3b53 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -4426,8 +4426,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
 
   // Return true if we have a zero-value tile.
   auto hasZeros = [&](ArrayRef<OpFoldResult> tiles) {
-    return llvm::any_of(
-        tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
+    return llvm::any_of(tiles, isZeroIndex);
   };
 
   // Verify tiles. Do not allow zero tiles.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f6ca109b84f9e..25b0635220f3b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3315,10 +3315,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
   SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
   SmallVector<OpFoldResult> steps = loop.getMixedStep();
 
-  if (llvm::all_of(
-          lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
-      llvm::all_of(
-          steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+  if (llvm::all_of(lbs, isZeroIndex) && llvm::all_of(steps, isOneIndex)) {
     return loop;
   }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7c2788f16a3b6..700be3ad35705 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(
 
   SmallVector<Value> threadIds = forallOp.getInductionVars();
   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
-      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+      numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
   int64_t nLoops = loopRanges.size();
   tiledOffsets.reserve(nLoops);
   tiledSizes.reserve(nLoops);
   for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {
     bool overflow = loopIdx >= numThreads.size();
-    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);
+    bool isZero = !overflow && isZeroIndex(numThreads[loopIdx]);
     // Degenerate case: take the whole domain.
     if (overflow || isZero) {
       tiledOffsets.push_back(loopRanges[loopIdx].offset);
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(
     OpFoldResult residualTileSize = makeComposedFoldedAffineApply(
         b, loc, i + j * m - n,
         {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});
-    if (!isConstantIntValue(residualTileSize, 0)) {
+    if (!isZeroIndex(residualTileSize)) {
       OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(
           b, loc, -i + m, {offsetPerThread, size});
       tileSizePerThread =
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   Operation *tiledOp = nullptr;
 
   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(
-      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });
+      numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); });
   SmallVector<Value> materializedNonZeroNumThreads =
       getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 0cc840403a020..faae77a6eecb3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -732,7 +732,7 @@ struct PackOpTiling
     // iterated or inner dims are not tiled. Otherwise, it will generate a
     // sequence of non-trivial ops (for partial tiles).
     for (auto offset : offsets.take_back(numTiles))
-      if (!isConstantIntValue(offset, 0))
+      if (!isZeroIndex(offset))
         return failure();
 
     for (auto iter :
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a0237c18cf2fe..1175c57694272 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1889,9 +1889,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
     // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
     // are 0.
     if (auto prev = src.getDefiningOp<SubViewOp>())
-      if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) {
-            return isConstantIntValue(val, 0);
-          }))
+      if (llvm::all_of(prev.getMixedOffsets(), isZeroIndex))
         return prev.getSource();
 
     return nullptr;
@@ -3285,11 +3283,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
     auto srcSizes = srcSubview.getMixedSizes();
     auto sizes = getMixedSizes();
     auto offsets = getMixedOffsets();
-    bool allOffsetsZero = llvm::all_of(
-        offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); });
+    bool allOffsetsZero = llvm::all_of(offsets, isZeroIndex);
     auto strides = getMixedStrides();
-    bool allStridesOne = llvm::all_of(
-        strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); });
+    bool allStridesOne = llvm::all_of(strides, isOneIndex);
     bool allSizesSame = llvm::equal(sizes, srcSizes);
     if (allOffsetsZero && allStridesOne && allSizesSame &&
         resultMemrefType == sourceMemrefType)
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
index 05ba6a3f38708..e28f7d3e4924a 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp
@@ -251,10 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> {
     // to do.
     SmallVector<OpFoldResult> indices =
         getAsOpFoldResult(loadStoreLikeOp.getIndices());
-    if (std::all_of(indices.begin(), indices.end(),
-                    [](const OpFoldResult &opFold) {
-                      return isConstantIntValue(opFold, 0);
-                    })) {
+    if (std::all_of(indices.begin(), indices.end(), isZeroIndex)) {
       return rewriter.notifyMatchFailure(
           loadStoreLikeOp, "no computation to extract: offsets are 0s");
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0cd7da5db9163..d7d42219bc7b6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
     tileSizes.resize(numLoops, zero);
     for (auto [index, range, nt] :
          llvm::enumerate(iterationDomain, numThreads)) {
-      if (isConstantIntValue(nt, 0))
+      if (isZeroIndex(nt))
         continue;
 
       tileSizes[index] = affine::makeComposedFoldedAffineApply(
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
 
       // Non-tiled cases, set the offset and size to the
       // `loopRange.offset/size`.
-      if (isConstantIntValue(nt, 0)) {
+      if (isZeroIndex(nt)) {
         offsets.push_back(loopRange.offset);
         sizes.push_back(loopRange.size);
         continue;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
           {loopRange.offset, nt, tileSize, loopRange.size});
 
       OpFoldResult size = tileSize;
-      if (!isConstantIntValue(residualTileSize, 0)) {
+      if (!isZeroIndex(residualTileSize)) {
         OpFoldResult sizeMinusOffsetPerThread =
             affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
                                                   {offset, loopRange.size});
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
 
       // Non-tiled cases, set the offset and size to the
       // `loopRange.offset/size`.
-      if (isConstantIntValue(tileSize, 0)) {
+      if (isZeroIndex(tileSize)) {
         offsets.push_back(loopRange.offset);
         sizes.push_back(loopRange.size);
         continue;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
   SmallVector<OpFoldResult> lbs, ubs, steps;
   for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
     // No loop if the tile size is 0.
-    if (isConstantIntValue(tileSize, 0))
+    if (isZeroIndex(tileSize))
       continue;
     lbs.push_back(loopRange.offset);
     ubs.push_back(loopRange.size);
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
     // Prune the zero numthreads.
     SmallVector<OpFoldResult> nonZeroNumThreads;
     for (auto nt : numThreads) {
-      if (isConstantIntValue(nt, 0))
+      if (isZeroIndex(nt))
         continue;
       nonZeroNumThreads.push_back(nt);
     }
@@ -1290,9 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
                               sliceSizes = sliceOp.getMixedSizes();
 
     // expect all strides of sliceOp being 1
-    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-          return !isConstantIntValue(ofr, 1);
-        }))
+    if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
       return failure();
 
     unsigned sliceResultNumber =
@@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
     SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides();
 
     // 9. Check all insert stride is 1.
-    if (llvm::any_of(strides, [](OpFoldResult stride) {
-          return !isConstantIntValue(stride, 1);
-        })) {
+    if (!llvm::all_of(strides, isOneIndex)) {
       return rewriter.notifyMatchFailure(
           candidateSliceOp, "containingOp's result yield with stride");
     }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index d9550fe18dc02..f95e38fc75c8d 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter,
   // If an `affine.apply` operation is generated for denormalization, the use
   // of `origLb` in those ops must not be replaced. These arent not generated
   // when `origLb == 0` and `origStep == 1`.
-  if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) {
+  if (!isZeroIndex(origLb) || !isOneIndex(origStep)) {
     if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) {
       preservedUses.insert(preservedUse);
     }
@@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
   }
   Value denormalizedIv;
   SmallPtrSet<Operation *, 2> preserve;
-  bool isStepOne = isConstantIntValue(origStep, 1);
-  bool isZeroBased = isConstantIntValue(origLb, 0);
+  bool isStepOne = isOneIndex(origStep);
+  bool isZeroBased = isZeroIndex(origLb);
 
   Value scaled = normalizedIv;
   if (!isStepOne) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index b2eca539194a8..649375b4c4037 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern<scf::ForOp> {
     // Check for single block, unit-stride for-loop that is generated by
     // sparsifier, which means no data dependence analysis is required,
     // and its loop-body is very restricted in form.
-    if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) ||
+    if (!op.getRegion().hasOneBlock() || !isOneIndex(op.getStep()) ||
         !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName()))
       return failure();
     // Analyze (!codegen) and rewrite (codegen) loop-body.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 29da32cd1791c..717ea1d0d7618 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2738,8 +2738,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
     return getResult();
   if (auto result = foldInsertAfterExtractSlice(*this))
     return result;
-  if (llvm::any_of(getMixedSizes(),
-                   [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
+  if (llvm::any_of(getMixedSizes(), isZeroIndex))
     return getDest();
   return OpFoldResult();
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 7778a02dbeaf4..41407064cb6d7 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -135,9 +135,9 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
   SmallVector<OpFoldResult> newStrides(rank, b.getIndexAttr(1));
   for (unsigned dim = 0; dim < rank; ++dim) {
     auto low = padOp.getMixedLowPad()[dim];
-    bool hasLowPad = !isConstantIntValue(low, 0);
+    bool hasLowPad = !isZeroIndex(low);
     auto high = padOp.getMixedHighPad()[dim];
-    bool hasHighPad = !isConstantIntValue(high, 0);
+    bool hasHighPad = !isZeroIndex(high);
     auto offset = offsets[dim];
     auto length = sizes[dim];
     // If the dim has no padding, we dont need to calculate new values for that
@@ -208,7 +208,7 @@ FailureOr<TilingResult> tensor::bubbleUpPadSlice(OpBuilder &b,
 
     // Check if newLength is zero. In that case, no SubTensorOp should be
     // executed.
-    if (isConstantIntValue(newLength, 0)) {
+    if (isZeroIndex(newLength)) {
       hasZeroLen = true;
     } else if (!hasZeroLen) {
       Value check = b.create<arith::CmpIOp>(
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
index a3de7f9b44ae6..9978aac1ee80e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp
@@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
     std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)>
         isZeroOffsetAndFullSize =
             [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) {
-              if (!isConstantIntValue(offset, 0))
+              if (!isZeroIndex(offset))
                 return false;
               FailureOr<bool> maybeEqual =
                   ValueBoundsConstraintSet::areEqual(sliceSize, size);
@@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice
       // Find the first expanded dim after the first dim with non-unit extracted
       // size.
       for (; i < e; ++i) {
-        if (!isConstantIntValue(sizes[indices[i]], 1)) {
+        if (!isOneIndex(sizes[indices[i]])) {
           // +1 to skip the first non-unit size dim.
           i++;
           break;
diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
index 858adfc436164..36cc31e614f21 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp
@@ -27,9 +27,7 @@ FailureOr<TilingResult> tensor::replaceExtractSliceWithTiledProducer(
     return failure();
 
   // `TilingInterface` currently only supports strides being 1.
-  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 1);
-      }))
+  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
     return failure();
 
   FailureOr<TilingResult> tiledResult = producerOp.generateResultTileValue(
@@ -49,9 +47,7 @@ FailureOr<TilingResult> tensor::replaceInsertSliceWithTiledConsumer(
     return failure();
 
   // `TilingInterface` currently only supports strides being 1.
-  if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 1);
-      }))
+  if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex))
     return failure();
 
   FailureOr<TilingResult> tiledResult =
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index fcb736aa031f3..51b51d8aa32e4 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -18,10 +18,13 @@ namespace mlir {
 bool isZeroIndex(OpFoldResult v) {
   if (!v)
     return false;
-  std::optional<int64_t> constint = getConstantIntValue(v);
-  if (!constint)
+  return isConstantIntValue(v, 0);
+}
+
+bool isOneIndex(OpFoldResult v) {
+  if (!v)
     return false;
-  return *constint == 0;
+  return isConstantIntValue(v, 1);
 }
 
 std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..4e5c60671b976 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -141,7 +141,7 @@ struct LinearizeVectorExtractStridedSlice final
     ArrayAttr offsets = extractOp.getOffsets();
     ArrayAttr sizes = extractOp.getSizes();
     ArrayAttr strides = extractOp.getStrides();
-    if (!isConstantIntValue(strides[0], 1))
+    if (!isOneIndex(strides[0]))
       return rewriter.notifyMatchFailure(
           extractOp, "Strided slice with stride != 1 is not supported.");
     Value srcVector = adaptor.getVector();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index b94c5fce64f83..83dc34e4b4139 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
     ArithIndexingBuilder idxBuilderf(rewriter, loc);
     for (auto i : llvm::seq<int64_t>(rankOffset, indices.size() - finalRank)) {
       OpFoldResult pos = extractPos[i - rankOffset];
-      if (isConstantIntValue(pos, 0))
+      if (isZeroIndex(pos))
         continue;
 
       Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);

hanhanW added 3 commits May 16, 2025 15:48
The revision adds isOneIndex helper, and simplifies the existing code
with the two methods. It removes some lambda, which makes code cleaner.

Signed-off-by: hanhanW <[email protected]>
@hanhanW hanhanW force-pushed the is-zero-one-index branch from cdcb0af to 1e4b65e Compare May 16, 2025 22:52
@hanhanW hanhanW changed the title [mlir][NFC] Simplify constant checks with isZeroIndex and isOneIndex. [mlir][NFC] Simplify constant checks with isOneInteger and renamed isZeroInteger. May 16, 2025
@hanhanW hanhanW requested a review from matthias-springer May 16, 2025 22:54
@hanhanW hanhanW merged commit c39915f into llvm:main May 20, 2025
9 of 10 checks passed
@hanhanW hanhanW deleted the is-zero-one-index branch May 20, 2025 21:53
rolfmorel added a commit to libxsmm/tpp-mlir that referenced this pull request Jun 12, 2025
* llvm/llvm-project#139340
```
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

* llvm/llvm-project#141466 &
llvm/llvm-project#141019
  * Add `BufferizationState &state` to `bufferize` and `getBuffer` 

* llvm/llvm-project#143159 &
llvm/llvm-project#142683 &
llvm/llvm-project#143779
  * Updates to `transform.apply_registered_pass` and its Python-bindings

* llvm/llvm-project#143217
* `tilingResult->mergeResult.replacements` ->
`tilingResult->replacements`

* llvm/llvm-project#140559 &
llvm/llvm-project#143871
* Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s
& fix which enables conversion again.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants