From c40ab2c51367714ec2527c2c10577dca342c2b20 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Fri, 9 May 2025 17:18:14 -0700 Subject: [PATCH 1/3] [mlir][NFC] Simplify constant checks with isZeroIndex and isOneIndex. The revision adds isOneIndex helper, and simplifies the existing code with the two methods. It removes some lambda, which makes code cleaner. Signed-off-by: hanhanW --- .../mlir/Dialect/Utils/StaticValueUtils.h | 4 ++++ .../MemRefToSPIRV/MemRefToSPIRV.cpp | 2 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 3 +-- .../TransformOps/LinalgTransformOps.cpp | 5 +---- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 9 +++++---- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 2 +- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 10 +++------- .../Transforms/ExtractAddressComputations.cpp | 5 +---- .../SCF/Transforms/TileUsingInterface.cpp | 20 ++++++++----------- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +++--- .../Transforms/SparseVectorization.cpp | 2 +- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 3 +-- .../Tensor/IR/TensorTilingInterfaceImpl.cpp | 6 +++--- .../Tensor/Transforms/ReshapePatterns.cpp | 4 ++-- .../SwapExtractSliceWithProducerPatterns.cpp | 8 ++------ mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 9 ++++++--- .../Vector/Transforms/VectorLinearize.cpp | 3 +-- .../Vector/Transforms/VectorTransforms.cpp | 2 +- 18 files changed, 45 insertions(+), 58 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 2a3a2defb810d..ea1a2384f8cba 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -28,6 +28,10 @@ namespace mlir { /// with attribute with value `0`. bool isZeroIndex(OpFoldResult v); +/// Return true if `v` is an IntegerAttr with value `1` of a ConstantIndexOp +/// with attribute with value `1`. +bool isOneIndex(OpFoldResult v); + /// Represents a range (offset, size, and stride) where each element of the /// triple may be dynamic or static. struct Range { diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 04bc62262c3d8..c9e7ae6f8bdb5 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite( OpFoldResult offset = getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter) .front(); - if (isConstantIntValue(offset, 0)) { + if (isZeroIndex(offset)) { rewriter.replaceOp(op, src); return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 96106cf7ae120..4fdeca47ed304 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -4488,8 +4488,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // Return true if we have a zero-value tile. auto hasZeros = [&](ArrayRef tiles) { - return llvm::any_of( - tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); }); + return llvm::any_of(tiles, isZeroIndex); }; // Verify tiles. Do not allow zero tiles. diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index a9370dc003830..d736fb141cb0c 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3401,10 +3401,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, SmallVector ubs = loop.getMixedUpperBound(); SmallVector steps = loop.getMixedStep(); - if (llvm::all_of( - lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) && - llvm::all_of( - steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) { + if (llvm::all_of(lbs, isZeroIndex) && llvm::all_of(steps, isOneIndex)) { return loop; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 7c2788f16a3b6..700be3ad35705 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" @@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes( SmallVector threadIds = forallOp.getInductionVars(); SmallVector nonZeroNumThreads = llvm::filter_to_vector( - numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); }); + numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); }); int64_t nLoops = loopRanges.size(); tiledOffsets.reserve(nLoops); tiledSizes.reserve(nLoops); for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) { bool overflow = loopIdx >= numThreads.size(); - bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0); + bool isZero = !overflow && isZeroIndex(numThreads[loopIdx]); // Degenerate case: take the whole domain. if (overflow || isZero) { tiledOffsets.push_back(loopRanges[loopIdx].offset); @@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes( OpFoldResult residualTileSize = makeComposedFoldedAffineApply( b, loc, i + j * m - n, {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size}); - if (!isConstantIntValue(residualTileSize, 0)) { + if (!isZeroIndex(residualTileSize)) { OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply( b, loc, -i + m, {offsetPerThread, size}); tileSizePerThread = @@ -655,7 +656,7 @@ FailureOr linalg::tileReductionUsingForall( Operation *tiledOp = nullptr; SmallVector nonZeroNumThreads = llvm::filter_to_vector( - numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); }); + numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); }); SmallVector materializedNonZeroNumThreads = getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index e8d460020cf69..7485df2cd73b3 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -732,7 +732,7 @@ struct PackOpTiling // iterated or inner dims are not tiled. Otherwise, it will generate a // sequence of non-trivial ops (for partial tiles). for (auto offset : offsets.take_back(numTiles)) - if (!isConstantIntValue(offset, 0)) + if (!isZeroIndex(offset)) return failure(); for (auto iter : diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index a0237c18cf2fe..1175c57694272 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1889,9 +1889,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) { // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets // are 0. if (auto prev = src.getDefiningOp()) - if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) { - return isConstantIntValue(val, 0); - })) + if (llvm::all_of(prev.getMixedOffsets(), isZeroIndex)) return prev.getSource(); return nullptr; @@ -3285,11 +3283,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { auto srcSizes = srcSubview.getMixedSizes(); auto sizes = getMixedSizes(); auto offsets = getMixedOffsets(); - bool allOffsetsZero = llvm::all_of( - offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }); + bool allOffsetsZero = llvm::all_of(offsets, isZeroIndex); auto strides = getMixedStrides(); - bool allStridesOne = llvm::all_of( - strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); }); + bool allStridesOne = llvm::all_of(strides, isOneIndex); bool allSizesSame = llvm::equal(sizes, srcSizes); if (allOffsetsZero && allStridesOne && allSizesSame && resultMemrefType == sourceMemrefType) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp index b906c727604dc..5a08900921ee5 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp @@ -251,10 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern { // to do. SmallVector indices = getAsOpFoldResult(loadStoreLikeOp.getIndices()); - if (std::all_of(indices.begin(), indices.end(), - [](const OpFoldResult &opFold) { - return isConstantIntValue(opFold, 0); - })) { + if (std::all_of(indices.begin(), indices.end(), isZeroIndex)) { return rewriter.notifyMatchFailure( loadStoreLikeOp, "no computation to extract: offsets are 0s"); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 0cd7da5db9163..d7d42219bc7b6 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, tileSizes.resize(numLoops, zero); for (auto [index, range, nt] : llvm::enumerate(iterationDomain, numThreads)) { - if (isConstantIntValue(nt, 0)) + if (isZeroIndex(nt)) continue; tileSizes[index] = affine::makeComposedFoldedAffineApply( @@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. - if (isConstantIntValue(nt, 0)) { + if (isZeroIndex(nt)) { offsets.push_back(loopRange.offset); sizes.push_back(loopRange.size); continue; @@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, {loopRange.offset, nt, tileSize, loopRange.size}); OpFoldResult size = tileSize; - if (!isConstantIntValue(residualTileSize, 0)) { + if (!isZeroIndex(residualTileSize)) { OpFoldResult sizeMinusOffsetPerThread = affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, {offset, loopRange.size}); @@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. - if (isConstantIntValue(tileSize, 0)) { + if (isZeroIndex(tileSize)) { offsets.push_back(loopRange.offset); sizes.push_back(loopRange.size); continue; @@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef loopRanges, SmallVector lbs, ubs, steps; for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { // No loop if the tile size is 0. - if (isConstantIntValue(tileSize, 0)) + if (isZeroIndex(tileSize)) continue; lbs.push_back(loopRange.offset); ubs.push_back(loopRange.size); @@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp( // Prune the zero numthreads. SmallVector nonZeroNumThreads; for (auto nt : numThreads) { - if (isConstantIntValue(nt, 0)) + if (isZeroIndex(nt)) continue; nonZeroNumThreads.push_back(nt); } @@ -1290,9 +1290,7 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( sliceSizes = sliceOp.getMixedSizes(); // expect all strides of sliceOp being 1 - if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { - return !isConstantIntValue(ofr, 1); - })) + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex)) return failure(); unsigned sliceResultNumber = @@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice( SmallVector strides = ossSliceOp.getMixedStrides(); // 9. Check all insert stride is 1. - if (llvm::any_of(strides, [](OpFoldResult stride) { - return !isConstantIntValue(stride, 1); - })) { + if (!llvm::all_of(strides, isOneIndex)) { return rewriter.notifyMatchFailure( candidateSliceOp, "containingOp's result yield with stride"); } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index d9550fe18dc02..f95e38fc75c8d 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, // If an `affine.apply` operation is generated for denormalization, the use // of `origLb` in those ops must not be replaced. These arent not generated // when `origLb == 0` and `origStep == 1`. - if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) { + if (!isZeroIndex(origLb) || !isOneIndex(origStep)) { if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) { preservedUses.insert(preservedUse); } @@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc, } Value denormalizedIv; SmallPtrSet preserve; - bool isStepOne = isConstantIntValue(origStep, 1); - bool isZeroBased = isConstantIntValue(origLb, 0); + bool isStepOne = isOneIndex(origStep); + bool isZeroBased = isZeroIndex(origLb); Value scaled = normalizedIv; if (!isStepOne) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index b2eca539194a8..649375b4c4037 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern { // Check for single block, unit-stride for-loop that is generated by // sparsifier, which means no data dependence analysis is required, // and its loop-body is very restricted in form. - if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) || + if (!op.getRegion().hasOneBlock() || !isOneIndex(op.getStep()) || !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) return failure(); // Analyze (!codegen) and rewrite (codegen) loop-body. diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 6c32476d8656f..6c17ebbb85c81 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2840,8 +2840,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) { return getResult(); if (auto result = foldInsertAfterExtractSlice(*this)) return result; - if (llvm::any_of(getMixedSizes(), - [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); })) + if (llvm::any_of(getMixedSizes(), isZeroIndex)) return getDest(); return OpFoldResult(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 7778a02dbeaf4..41407064cb6d7 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -135,9 +135,9 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, SmallVector newStrides(rank, b.getIndexAttr(1)); for (unsigned dim = 0; dim < rank; ++dim) { auto low = padOp.getMixedLowPad()[dim]; - bool hasLowPad = !isConstantIntValue(low, 0); + bool hasLowPad = !isZeroIndex(low); auto high = padOp.getMixedHighPad()[dim]; - bool hasHighPad = !isConstantIntValue(high, 0); + bool hasHighPad = !isZeroIndex(high); auto offset = offsets[dim]; auto length = sizes[dim]; // If the dim has no padding, we dont need to calculate new values for that @@ -208,7 +208,7 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, // Check if newLength is zero. In that case, no SubTensorOp should be // executed. - if (isConstantIntValue(newLength, 0)) { + if (isZeroIndex(newLength)) { hasZeroLen = true; } else if (!hasZeroLen) { Value check = b.create( diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index a3de7f9b44ae6..9978aac1ee80e 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice std::function isZeroOffsetAndFullSize = [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) { - if (!isConstantIntValue(offset, 0)) + if (!isZeroIndex(offset)) return false; FailureOr maybeEqual = ValueBoundsConstraintSet::areEqual(sliceSize, size); @@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice // Find the first expanded dim after the first dim with non-unit extracted // size. for (; i < e; ++i) { - if (!isConstantIntValue(sizes[indices[i]], 1)) { + if (!isOneIndex(sizes[indices[i]])) { // +1 to skip the first non-unit size dim. i++; break; diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp index 858adfc436164..36cc31e614f21 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp @@ -27,9 +27,7 @@ FailureOr tensor::replaceExtractSliceWithTiledProducer( return failure(); // `TilingInterface` currently only supports strides being 1. - if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { - return !isConstantIntValue(ofr, 1); - })) + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex)) return failure(); FailureOr tiledResult = producerOp.generateResultTileValue( @@ -49,9 +47,7 @@ FailureOr tensor::replaceInsertSliceWithTiledConsumer( return failure(); // `TilingInterface` currently only supports strides being 1. - if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { - return !isConstantIntValue(ofr, 1); - })) + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex)) return failure(); FailureOr tiledResult = diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index fac836ebd7a36..2edfdf2508895 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -18,10 +18,13 @@ namespace mlir { bool isZeroIndex(OpFoldResult v) { if (!v) return false; - std::optional constint = getConstantIntValue(v); - if (!constint) + return isConstantIntValue(v, 0); +} + +bool isOneIndex(OpFoldResult v) { + if (!v) return false; - return *constint == 0; + return isConstantIntValue(v, 1); } std::tuple, SmallVector, diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 060ce7d1d6643..dc87424df3854 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -116,8 +116,7 @@ static bool stridesAllOne(TOp op) { std::is_same_v, "expected vector.extract_strided_slice or vector.insert_strided_slice"); ArrayAttr strides = op.getStrides(); - return llvm::all_of( - strides, [](auto stride) { return isConstantIntValue(stride, 1); }); + return llvm::all_of(strides, isOneIndex); } /// Convert an array of attributes into a vector of integers, if possible. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index c635be6e83b6a..ca15d410efc7a 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern { ArithIndexingBuilder idxBuilderf(rewriter, loc); for (auto i : llvm::seq(rankOffset, indices.size() - finalRank)) { OpFoldResult pos = extractPos[i - rankOffset]; - if (isConstantIntValue(pos, 0)) + if (isZeroIndex(pos)) continue; Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos); From 6f5eb8b28bcae88ce22b0c61ada71c2306a3c410 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Fri, 16 May 2025 15:50:15 -0700 Subject: [PATCH 2/3] Update comments and implementation. Signed-off-by: hanhanW --- mlir/include/mlir/Dialect/Utils/StaticValueUtils.h | 6 ++---- mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 12 ++---------- 2 files changed, 4 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index ea1a2384f8cba..c64dd88b8f52d 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -24,12 +24,10 @@ namespace mlir { -/// Return true if `v` is an IntegerAttr with value `0` of a ConstantIndexOp -/// with attribute with value `0`. +/// Return true if `v` is an IntegerAttr with value `0`. bool isZeroIndex(OpFoldResult v); -/// Return true if `v` is an IntegerAttr with value `1` of a ConstantIndexOp -/// with attribute with value `1`. +/// Return true if `v` is an IntegerAttr with value `1`. bool isOneIndex(OpFoldResult v); /// Represents a range (offset, size, and stride) where each element of the diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 2edfdf2508895..3ad3a43fbed0e 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -15,17 +15,9 @@ namespace mlir { -bool isZeroIndex(OpFoldResult v) { - if (!v) - return false; - return isConstantIntValue(v, 0); -} +bool isZeroIndex(OpFoldResult v) { return isConstantIntValue(v, 0); } -bool isOneIndex(OpFoldResult v) { - if (!v) - return false; - return isConstantIntValue(v, 1); -} +bool isOneIndex(OpFoldResult v) { return isConstantIntValue(v, 1); } std::tuple, SmallVector, SmallVector> From 1e4b65eb879c78879647f3aafbf6e7bd71a7cbe4 Mon Sep 17 00:00:00 2001 From: hanhanW Date: Fri, 16 May 2025 15:51:54 -0700 Subject: [PATCH 3/3] Rename to isZeroInteger and isOneInteger. Signed-off-by: hanhanW --- .../mlir/Dialect/Utils/StaticValueUtils.h | 4 ++-- .../MemRefToSPIRV/MemRefToSPIRV.cpp | 2 +- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 2 +- .../TransformOps/LinalgTransformOps.cpp | 2 +- .../Transforms/ConvertToDestinationStyle.cpp | 4 ++-- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 8 ++++---- .../Linalg/Transforms/TilingInterfaceImpl.cpp | 4 ++-- mlir/lib/Dialect/Linalg/Utils/Utils.cpp | 8 ++++---- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 6 +++--- .../Transforms/ExtractAddressComputations.cpp | 2 +- .../SCF/Transforms/TileUsingInterface.cpp | 20 +++++++++---------- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 6 +++--- .../Transforms/SparseVectorization.cpp | 2 +- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 2 +- .../Tensor/IR/TensorTilingInterfaceImpl.cpp | 6 +++--- .../BufferizableOpInterfaceImpl.cpp | 2 +- .../Tensor/Transforms/ReshapePatterns.cpp | 4 ++-- .../SwapExtractSliceWithProducerPatterns.cpp | 4 ++-- mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 4 ++-- .../Vector/Transforms/VectorLinearize.cpp | 2 +- .../Transforms/VectorTransferOpTransforms.cpp | 2 +- .../Vector/Transforms/VectorTransforms.cpp | 2 +- 22 files changed, 49 insertions(+), 49 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index c64dd88b8f52d..b37fb55b67931 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -25,10 +25,10 @@ namespace mlir { /// Return true if `v` is an IntegerAttr with value `0`. -bool isZeroIndex(OpFoldResult v); +bool isZeroInteger(OpFoldResult v); /// Return true if `v` is an IntegerAttr with value `1`. -bool isOneIndex(OpFoldResult v); +bool isOneInteger(OpFoldResult v); /// Represents a range (offset, size, and stride) where each element of the /// triple may be dynamic or static. diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index c9e7ae6f8bdb5..fdf799a20efdd 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite( OpFoldResult offset = getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter) .front(); - if (isZeroIndex(offset)) { + if (isZeroInteger(offset)) { rewriter.replaceOp(op, src); return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 4fdeca47ed304..b7f78607e6241 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -4488,7 +4488,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // Return true if we have a zero-value tile. auto hasZeros = [&](ArrayRef tiles) { - return llvm::any_of(tiles, isZeroIndex); + return llvm::any_of(tiles, isZeroInteger); }; // Verify tiles. Do not allow zero tiles. diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index d736fb141cb0c..1c3b621828315 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3401,7 +3401,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, SmallVector ubs = loop.getMixedUpperBound(); SmallVector steps = loop.getMixedStep(); - if (llvm::all_of(lbs, isZeroIndex) && llvm::all_of(steps, isOneIndex)) { + if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) { return loop; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index b1340be04e011..a62510deefc4a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -441,8 +441,8 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, // If the `padOp` has a nofold attribute and all paddings are known to be 0, // explicitly insert a `linalg.copy`. if (padOp.getNofoldAttr() && - llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) && - llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) { + llvm::all_of(padOp.getMixedLowPad(), isZeroInteger) && + llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) { using bufferization::AllocTensorOp; Value allocated = rewriter.create(loc, resultType, dynamicSizes); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 700be3ad35705..4162aa0b71e6d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -377,13 +377,13 @@ static void calculateTileOffsetsAndSizes( SmallVector threadIds = forallOp.getInductionVars(); SmallVector nonZeroNumThreads = llvm::filter_to_vector( - numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); }); + numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); }); int64_t nLoops = loopRanges.size(); tiledOffsets.reserve(nLoops); tiledSizes.reserve(nLoops); for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) { bool overflow = loopIdx >= numThreads.size(); - bool isZero = !overflow && isZeroIndex(numThreads[loopIdx]); + bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]); // Degenerate case: take the whole domain. if (overflow || isZero) { tiledOffsets.push_back(loopRanges[loopIdx].offset); @@ -414,7 +414,7 @@ static void calculateTileOffsetsAndSizes( OpFoldResult residualTileSize = makeComposedFoldedAffineApply( b, loc, i + j * m - n, {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size}); - if (!isZeroIndex(residualTileSize)) { + if (!isZeroInteger(residualTileSize)) { OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply( b, loc, -i + m, {offsetPerThread, size}); tileSizePerThread = @@ -656,7 +656,7 @@ FailureOr linalg::tileReductionUsingForall( Operation *tiledOp = nullptr; SmallVector nonZeroNumThreads = llvm::filter_to_vector( - numThreads, [](OpFoldResult ofr) { return !isZeroIndex(ofr); }); + numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); }); SmallVector materializedNonZeroNumThreads = getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 7485df2cd73b3..7c14cc16437fe 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -369,7 +369,7 @@ struct LinalgOpPartialReductionInterface SmallVector tiledShape; for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) { - if (isZeroIndex(tileSize)) { + if (isZeroInteger(tileSize)) { tiledShape.push_back(dimSize); } else { tiledShape.push_back(tileSize); @@ -732,7 +732,7 @@ struct PackOpTiling // iterated or inner dims are not tiled. Otherwise, it will generate a // sequence of non-trivial ops (for partial tiles). for (auto offset : offsets.take_back(numTiles)) - if (!isZeroIndex(offset)) + if (!isZeroInteger(offset)) return failure(); for (auto iter : diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index d3d301ca093b1..bae06c003fd97 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -59,7 +59,7 @@ struct TileCheck : public AffineExprVisitor { TileCheck(ArrayRef tileSizes) : tileSizes(tileSizes) {} void visitDimExpr(AffineDimExpr expr) { - isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]); + isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]); } void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { visit(expr.getLHS()); @@ -741,7 +741,7 @@ SmallVector computeTileOffsets(OpBuilder &b, Location loc, SmallVector offsets; for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n"); - bool isTiled = !isZeroIndex(tileSizes[idx]); + bool isTiled = !isZeroInteger(tileSizes[idx]); offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0)); LLVM_DEBUG(llvm::dbgs() << "computeTileOffsets: " << offsets.back() << "\n"); @@ -754,7 +754,7 @@ SmallVector computeTileSizes(OpBuilder &b, Location loc, ArrayRef sizeBounds) { SmallVector sizes; for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) { - bool isTiled = !isZeroIndex(tileSizes[idx]); + bool isTiled = !isZeroInteger(tileSizes[idx]); // Before composing, we need to make range a closed interval. OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx]; AffineExpr d0 = getAffineDimExpr(0, b.getContext()); @@ -810,7 +810,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, bool omitPartialTileCheck) { assert(ivs.size() == static_cast(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), - [](OpFoldResult v) { return !isZeroIndex(v); })) && + [](OpFoldResult v) { return !isZeroInteger(v); })) && "expected as many ivs as non-zero sizes"); // Construct (potentially temporary) mins and maxes on which to apply maps diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 1175c57694272..95c8b72643735 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1889,7 +1889,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) { // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets // are 0. if (auto prev = src.getDefiningOp()) - if (llvm::all_of(prev.getMixedOffsets(), isZeroIndex)) + if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger)) return prev.getSource(); return nullptr; @@ -3283,9 +3283,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { auto srcSizes = srcSubview.getMixedSizes(); auto sizes = getMixedSizes(); auto offsets = getMixedOffsets(); - bool allOffsetsZero = llvm::all_of(offsets, isZeroIndex); + bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger); auto strides = getMixedStrides(); - bool allStridesOne = llvm::all_of(strides, isOneIndex); + bool allStridesOne = llvm::all_of(strides, isOneInteger); bool allSizesSame = llvm::equal(sizes, srcSizes); if (allOffsetsZero && allStridesOne && allSizesSame && resultMemrefType == sourceMemrefType) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp index 5a08900921ee5..9e942f10b1f16 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp @@ -251,7 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern { // to do. SmallVector indices = getAsOpFoldResult(loadStoreLikeOp.getIndices()); - if (std::all_of(indices.begin(), indices.end(), isZeroIndex)) { + if (std::all_of(indices.begin(), indices.end(), isZeroInteger)) { return rewriter.notifyMatchFailure( loadStoreLikeOp, "no computation to extract: offsets are 0s"); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index d7d42219bc7b6..719e2c6fa459e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, tileSizes.resize(numLoops, zero); for (auto [index, range, nt] : llvm::enumerate(iterationDomain, numThreads)) { - if (isZeroIndex(nt)) + if (isZeroInteger(nt)) continue; tileSizes[index] = affine::makeComposedFoldedAffineApply( @@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. - if (isZeroIndex(nt)) { + if (isZeroInteger(nt)) { offsets.push_back(loopRange.offset); sizes.push_back(loopRange.size); continue; @@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, {loopRange.offset, nt, tileSize, loopRange.size}); OpFoldResult size = tileSize; - if (!isZeroIndex(residualTileSize)) { + if (!isZeroInteger(residualTileSize)) { OpFoldResult sizeMinusOffsetPerThread = affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, {offset, loopRange.size}); @@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. - if (isZeroIndex(tileSize)) { + if (isZeroInteger(tileSize)) { offsets.push_back(loopRange.offset); sizes.push_back(loopRange.size); continue; @@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef loopRanges, SmallVector lbs, ubs, steps; for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { // No loop if the tile size is 0. - if (isZeroIndex(tileSize)) + if (isZeroInteger(tileSize)) continue; lbs.push_back(loopRange.offset); ubs.push_back(loopRange.size); @@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp( // Prune the zero numthreads. SmallVector nonZeroNumThreads; for (auto nt : numThreads) { - if (isZeroIndex(nt)) + if (isZeroInteger(nt)) continue; nonZeroNumThreads.push_back(nt); } @@ -551,7 +551,7 @@ static LogicalResult generateLoopNest( YieldTiledValuesFn tiledBodyFn, SmallVector &loops) { // If the tile sizes are all zero, no loops are generated. Just call the // callback function to handle untiled case. - if (llvm::all_of(tileSizes, isZeroIndex)) { + if (llvm::all_of(tileSizes, isZeroInteger)) { SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, @@ -999,7 +999,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 5b. Early return cloned op if tiling is not happening. We can not // return the original op because it could lead to `rewriter.replaceOp(op, // op->getResults())` and users would get crash. - if (llvm::all_of(tileSizes, isZeroIndex)) { + if (llvm::all_of(tileSizes, isZeroInteger)) { tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); tilingResult = TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), @@ -1290,7 +1290,7 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( sliceSizes = sliceOp.getMixedSizes(); // expect all strides of sliceOp being 1 - if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex)) + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) return failure(); unsigned sliceResultNumber = @@ -2112,7 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice( SmallVector strides = ossSliceOp.getMixedStrides(); // 9. Check all insert stride is 1. - if (!llvm::all_of(strides, isOneIndex)) { + if (!llvm::all_of(strides, isOneInteger)) { return rewriter.notifyMatchFailure( candidateSliceOp, "containingOp's result yield with stride"); } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index f95e38fc75c8d..8ab5bdc0c5dc5 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, // If an `affine.apply` operation is generated for denormalization, the use // of `origLb` in those ops must not be replaced. These arent not generated // when `origLb == 0` and `origStep == 1`. - if (!isZeroIndex(origLb) || !isOneIndex(origStep)) { + if (!isZeroInteger(origLb) || !isOneInteger(origStep)) { if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) { preservedUses.insert(preservedUse); } @@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc, } Value denormalizedIv; SmallPtrSet preserve; - bool isStepOne = isOneIndex(origStep); - bool isZeroBased = isZeroIndex(origLb); + bool isStepOne = isOneInteger(origStep); + bool isZeroBased = isZeroInteger(origLb); Value scaled = normalizedIv; if (!isStepOne) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 649375b4c4037..3d963dea2f572 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern { // Check for single block, unit-stride for-loop that is generated by // sparsifier, which means no data dependence analysis is required, // and its loop-body is very restricted in form. - if (!op.getRegion().hasOneBlock() || !isOneIndex(op.getStep()) || + if (!op.getRegion().hasOneBlock() || !isOneInteger(op.getStep()) || !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) return failure(); // Analyze (!codegen) and rewrite (codegen) loop-body. diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 6c17ebbb85c81..8db563fb7a25f 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2840,7 +2840,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) { return getResult(); if (auto result = foldInsertAfterExtractSlice(*this)) return result; - if (llvm::any_of(getMixedSizes(), isZeroIndex)) + if (llvm::any_of(getMixedSizes(), isZeroInteger)) return getDest(); return OpFoldResult(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 41407064cb6d7..92540bd56ecbc 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -135,9 +135,9 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, SmallVector newStrides(rank, b.getIndexAttr(1)); for (unsigned dim = 0; dim < rank; ++dim) { auto low = padOp.getMixedLowPad()[dim]; - bool hasLowPad = !isZeroIndex(low); + bool hasLowPad = !isZeroInteger(low); auto high = padOp.getMixedHighPad()[dim]; - bool hasHighPad = !isZeroIndex(high); + bool hasHighPad = !isZeroInteger(high); auto offset = offsets[dim]; auto length = sizes[dim]; // If the dim has no padding, we dont need to calculate new values for that @@ -208,7 +208,7 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, // Check if newLength is zero. In that case, no SubTensorOp should be // executed. - if (isZeroIndex(newLength)) { + if (isZeroInteger(newLength)) { hasZeroLen = true; } else if (!hasZeroLen) { Value check = b.create( diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index c0e697292d2a0..81a2480940742 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -646,7 +646,7 @@ static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp, // Dest is not read if it is entirely overwritten. E.g.: // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32> bool allOffsetsZero = - llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex); + llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger); RankedTensorType destType = insertSliceOp.getDestType(); bool sizesMatchDestSizes = areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape()); diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 9978aac1ee80e..2b229d60c691b 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice std::function isZeroOffsetAndFullSize = [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) { - if (!isZeroIndex(offset)) + if (!isZeroInteger(offset)) return false; FailureOr maybeEqual = ValueBoundsConstraintSet::areEqual(sliceSize, size); @@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice // Find the first expanded dim after the first dim with non-unit extracted // size. for (; i < e; ++i) { - if (!isOneIndex(sizes[indices[i]])) { + if (!isOneInteger(sizes[indices[i]])) { // +1 to skip the first non-unit size dim. i++; break; diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp index 36cc31e614f21..6f33f9b55ceb6 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp @@ -27,7 +27,7 @@ FailureOr tensor::replaceExtractSliceWithTiledProducer( return failure(); // `TilingInterface` currently only supports strides being 1. - if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex)) + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) return failure(); FailureOr tiledResult = producerOp.generateResultTileValue( @@ -47,7 +47,7 @@ FailureOr tensor::replaceInsertSliceWithTiledConsumer( return failure(); // `TilingInterface` currently only supports strides being 1. - if (!llvm::all_of(sliceOp.getMixedStrides(), isOneIndex)) + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) return failure(); FailureOr tiledResult = diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index 3ad3a43fbed0e..29f7bd6857c27 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -15,9 +15,9 @@ namespace mlir { -bool isZeroIndex(OpFoldResult v) { return isConstantIntValue(v, 0); } +bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(v, 0); } -bool isOneIndex(OpFoldResult v) { return isConstantIntValue(v, 1); } +bool isOneInteger(OpFoldResult v) { return isConstantIntValue(v, 1); } std::tuple, SmallVector, SmallVector> diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index dc87424df3854..678a88627ca82 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -116,7 +116,7 @@ static bool stridesAllOne(TOp op) { std::is_same_v, "expected vector.extract_strided_slice or vector.insert_strided_slice"); ArrayAttr strides = op.getStrides(); - return llvm::all_of(strides, isOneIndex); + return llvm::all_of(strides, isOneInteger); } /// Convert an array of attributes into a vector of integers, if possible. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index d4d07c7eadc77..7dbb7a334fe62 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -538,7 +538,7 @@ static SmallVector getCollapsedIndices(RewriterBase &rewriter, indices.begin(), indices.begin() + firstDimToCollapse); SmallVector indicesToCollapse(indices.begin() + firstDimToCollapse, indices.end()); - if (llvm::all_of(indicesToCollapse, isZeroIndex)) { + if (llvm::all_of(indicesToCollapse, isZeroInteger)) { indicesAfterCollapsing.push_back(indicesToCollapse[0]); return indicesAfterCollapsing; } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index ca15d410efc7a..71c557f7eda06 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern { ArithIndexingBuilder idxBuilderf(rewriter, loc); for (auto i : llvm::seq(rankOffset, indices.size() - finalRank)) { OpFoldResult pos = extractPos[i - rankOffset]; - if (isZeroIndex(pos)) + if (isZeroInteger(pos)) continue; Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);