diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h index 2a3a2defb810d..b37fb55b67931 100644 --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -24,9 +24,11 @@ namespace mlir { -/// Return true if `v` is an IntegerAttr with value `0` of a ConstantIndexOp -/// with attribute with value `0`. -bool isZeroIndex(OpFoldResult v); +/// Return true if `v` is an IntegerAttr with value `0`. +bool isZeroInteger(OpFoldResult v); + +/// Return true if `v` is an IntegerAttr with value `1`. +bool isOneInteger(OpFoldResult v); /// Represents a range (offset, size, and stride) where each element of the /// triple may be dynamic or static. diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 04bc62262c3d8..fdf799a20efdd 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -897,7 +897,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite( OpFoldResult offset = getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(), rewriter) .front(); - if (isConstantIntValue(offset, 0)) { + if (isZeroInteger(offset)) { rewriter.replaceOp(op, src); return success(); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 8a8bd5e232c40..5fc3ace5d6aab 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -4488,8 +4488,7 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // Return true if we have a zero-value tile. auto hasZeros = [&](ArrayRef tiles) { - return llvm::any_of( - tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); }); + return llvm::any_of(tiles, isZeroInteger); }; // Verify tiles. Do not allow zero tiles. diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index a9370dc003830..1c3b621828315 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -3401,10 +3401,7 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, SmallVector ubs = loop.getMixedUpperBound(); SmallVector steps = loop.getMixedStep(); - if (llvm::all_of( - lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) && - llvm::all_of( - steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) { + if (llvm::all_of(lbs, isZeroInteger) && llvm::all_of(steps, isOneInteger)) { return loop; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index b1340be04e011..a62510deefc4a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -441,8 +441,8 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, // If the `padOp` has a nofold attribute and all paddings are known to be 0, // explicitly insert a `linalg.copy`. if (padOp.getNofoldAttr() && - llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) && - llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) { + llvm::all_of(padOp.getMixedLowPad(), isZeroInteger) && + llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) { using bufferization::AllocTensorOp; Value allocated = rewriter.create(loc, resultType, dynamicSizes); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 7c2788f16a3b6..4162aa0b71e6d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -23,6 +23,7 @@ #include "mlir/Dialect/SCF/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinOps.h" @@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes( SmallVector threadIds = forallOp.getInductionVars(); SmallVector nonZeroNumThreads = llvm::filter_to_vector( - numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); }); + numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); }); int64_t nLoops = loopRanges.size(); tiledOffsets.reserve(nLoops); tiledSizes.reserve(nLoops); for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) { bool overflow = loopIdx >= numThreads.size(); - bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0); + bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]); // Degenerate case: take the whole domain. if (overflow || isZero) { tiledOffsets.push_back(loopRanges[loopIdx].offset); @@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes( OpFoldResult residualTileSize = makeComposedFoldedAffineApply( b, loc, i + j * m - n, {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size}); - if (!isConstantIntValue(residualTileSize, 0)) { + if (!isZeroInteger(residualTileSize)) { OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply( b, loc, -i + m, {offsetPerThread, size}); tileSizePerThread = @@ -655,7 +656,7 @@ FailureOr linalg::tileReductionUsingForall( Operation *tiledOp = nullptr; SmallVector nonZeroNumThreads = llvm::filter_to_vector( - numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); }); + numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); }); SmallVector materializedNonZeroNumThreads = getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index e8d460020cf69..7c14cc16437fe 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -369,7 +369,7 @@ struct LinalgOpPartialReductionInterface SmallVector tiledShape; for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) { - if (isZeroIndex(tileSize)) { + if (isZeroInteger(tileSize)) { tiledShape.push_back(dimSize); } else { tiledShape.push_back(tileSize); @@ -732,7 +732,7 @@ struct PackOpTiling // iterated or inner dims are not tiled. Otherwise, it will generate a // sequence of non-trivial ops (for partial tiles). for (auto offset : offsets.take_back(numTiles)) - if (!isConstantIntValue(offset, 0)) + if (!isZeroInteger(offset)) return failure(); for (auto iter : diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index d3d301ca093b1..bae06c003fd97 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -59,7 +59,7 @@ struct TileCheck : public AffineExprVisitor { TileCheck(ArrayRef tileSizes) : tileSizes(tileSizes) {} void visitDimExpr(AffineDimExpr expr) { - isTiled |= !isZeroIndex(tileSizes[expr.getPosition()]); + isTiled |= !isZeroInteger(tileSizes[expr.getPosition()]); } void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { visit(expr.getLHS()); @@ -741,7 +741,7 @@ SmallVector computeTileOffsets(OpBuilder &b, Location loc, SmallVector offsets; for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { LLVM_DEBUG(llvm::dbgs() << "makeTiledShapes: for loop#" << idx << "\n"); - bool isTiled = !isZeroIndex(tileSizes[idx]); + bool isTiled = !isZeroInteger(tileSizes[idx]); offsets.push_back(isTiled ? ivs[idxIvs++] : b.getIndexAttr(0)); LLVM_DEBUG(llvm::dbgs() << "computeTileOffsets: " << offsets.back() << "\n"); @@ -754,7 +754,7 @@ SmallVector computeTileSizes(OpBuilder &b, Location loc, ArrayRef sizeBounds) { SmallVector sizes; for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) { - bool isTiled = !isZeroIndex(tileSizes[idx]); + bool isTiled = !isZeroInteger(tileSizes[idx]); // Before composing, we need to make range a closed interval. OpFoldResult size = isTiled ? tileSizes[idx] : sizeBounds[idx]; AffineExpr d0 = getAffineDimExpr(0, b.getContext()); @@ -810,7 +810,7 @@ computeAllSliceParameters(OpBuilder &builder, Location loc, LinalgOp linalgOp, bool omitPartialTileCheck) { assert(ivs.size() == static_cast(llvm::count_if( llvm::make_range(tileSizes.begin(), tileSizes.end()), - [](OpFoldResult v) { return !isZeroIndex(v); })) && + [](OpFoldResult v) { return !isZeroInteger(v); })) && "expected as many ivs as non-zero sizes"); // Construct (potentially temporary) mins and maxes on which to apply maps diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 82702789c2913..cab0ab8d15d5d 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1894,9 +1894,7 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) { // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets // are 0. if (auto prev = src.getDefiningOp()) - if (llvm::all_of(prev.getMixedOffsets(), [](OpFoldResult val) { - return isConstantIntValue(val, 0); - })) + if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger)) return prev.getSource(); return nullptr; @@ -3290,11 +3288,9 @@ OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) { auto srcSizes = srcSubview.getMixedSizes(); auto sizes = getMixedSizes(); auto offsets = getMixedOffsets(); - bool allOffsetsZero = llvm::all_of( - offsets, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }); + bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger); auto strides = getMixedStrides(); - bool allStridesOne = llvm::all_of( - strides, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); }); + bool allStridesOne = llvm::all_of(strides, isOneInteger); bool allSizesSame = llvm::equal(sizes, srcSizes); if (allOffsetsZero && allStridesOne && allSizesSame && resultMemrefType == sourceMemrefType) diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp index 723b4d01186f9..2f5c9436fb8c7 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp @@ -251,9 +251,7 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern { // to do. SmallVector indices = getAsOpFoldResult(loadStoreLikeOp.getIndices()); - if (llvm::all_of(indices, [](const OpFoldResult &opFold) { - return isConstantIntValue(opFold, 0); - })) { + if (llvm::all_of(indices, isZeroInteger)) { return rewriter.notifyMatchFailure( loadStoreLikeOp, "no computation to extract: offsets are 0s"); } diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 0cd7da5db9163..719e2c6fa459e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op, tileSizes.resize(numLoops, zero); for (auto [index, range, nt] : llvm::enumerate(iterationDomain, numThreads)) { - if (isConstantIntValue(nt, 0)) + if (isZeroInteger(nt)) continue; tileSizes[index] = affine::makeComposedFoldedAffineApply( @@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. - if (isConstantIntValue(nt, 0)) { + if (isZeroInteger(nt)) { offsets.push_back(loopRange.offset); sizes.push_back(loopRange.size); continue; @@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, {loopRange.offset, nt, tileSize, loopRange.size}); OpFoldResult size = tileSize; - if (!isConstantIntValue(residualTileSize, 0)) { + if (!isZeroInteger(residualTileSize)) { OpFoldResult sizeMinusOffsetPerThread = affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, {offset, loopRange.size}); @@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. - if (isConstantIntValue(tileSize, 0)) { + if (isZeroInteger(tileSize)) { offsets.push_back(loopRange.offset); sizes.push_back(loopRange.size); continue; @@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef loopRanges, SmallVector lbs, ubs, steps; for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) { // No loop if the tile size is 0. - if (isConstantIntValue(tileSize, 0)) + if (isZeroInteger(tileSize)) continue; lbs.push_back(loopRange.offset); ubs.push_back(loopRange.size); @@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp( // Prune the zero numthreads. SmallVector nonZeroNumThreads; for (auto nt : numThreads) { - if (isConstantIntValue(nt, 0)) + if (isZeroInteger(nt)) continue; nonZeroNumThreads.push_back(nt); } @@ -551,7 +551,7 @@ static LogicalResult generateLoopNest( YieldTiledValuesFn tiledBodyFn, SmallVector &loops) { // If the tile sizes are all zero, no loops are generated. Just call the // callback function to handle untiled case. - if (llvm::all_of(tileSizes, isZeroIndex)) { + if (llvm::all_of(tileSizes, isZeroInteger)) { SmallVector tiledResults; SmallVector> resultOffsets, resultSizes; return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors, @@ -999,7 +999,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, // 5b. Early return cloned op if tiling is not happening. We can not // return the original op because it could lead to `rewriter.replaceOp(op, // op->getResults())` and users would get crash. - if (llvm::all_of(tileSizes, isZeroIndex)) { + if (llvm::all_of(tileSizes, isZeroInteger)) { tiledResults.append(clonedOp->result_begin(), clonedOp->result_end()); tilingResult = TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(), @@ -1290,9 +1290,7 @@ FailureOr> mlir::scf::yieldReplacementForFusedProducer( sliceSizes = sliceOp.getMixedSizes(); // expect all strides of sliceOp being 1 - if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { - return !isConstantIntValue(ofr, 1); - })) + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) return failure(); unsigned sliceResultNumber = @@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice( SmallVector strides = ossSliceOp.getMixedStrides(); // 9. Check all insert stride is 1. - if (llvm::any_of(strides, [](OpFoldResult stride) { - return !isConstantIntValue(stride, 1); - })) { + if (!llvm::all_of(strides, isOneInteger)) { return rewriter.notifyMatchFailure( candidateSliceOp, "containingOp's result yield with stride"); } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index d9550fe18dc02..8ab5bdc0c5dc5 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -768,7 +768,7 @@ static void denormalizeInductionVariableForIndexType(RewriterBase &rewriter, // If an `affine.apply` operation is generated for denormalization, the use // of `origLb` in those ops must not be replaced. These arent not generated // when `origLb == 0` and `origStep == 1`. - if (!isConstantIntValue(origLb, 0) || !isConstantIntValue(origStep, 1)) { + if (!isZeroInteger(origLb) || !isOneInteger(origStep)) { if (Operation *preservedUse = denormalizedIvVal.getDefiningOp()) { preservedUses.insert(preservedUse); } @@ -785,8 +785,8 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc, } Value denormalizedIv; SmallPtrSet preserve; - bool isStepOne = isConstantIntValue(origStep, 1); - bool isZeroBased = isConstantIntValue(origLb, 0); + bool isStepOne = isOneInteger(origStep); + bool isZeroBased = isZeroInteger(origLb); Value scaled = normalizedIv; if (!isStepOne) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index b2eca539194a8..3d963dea2f572 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -614,7 +614,7 @@ struct ForOpRewriter : public OpRewritePattern { // Check for single block, unit-stride for-loop that is generated by // sparsifier, which means no data dependence analysis is required, // and its loop-body is very restricted in form. - if (!op.getRegion().hasOneBlock() || !isConstantIntValue(op.getStep(), 1) || + if (!op.getRegion().hasOneBlock() || !isOneInteger(op.getStep()) || !op->hasAttr(LoopEmitter::getLoopEmitterLoopAttrName())) return failure(); // Analyze (!codegen) and rewrite (codegen) loop-body. diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 9a0d5d7e16960..30ca20fc0d883 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -2839,8 +2839,7 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) { return getResult(); if (auto result = foldInsertAfterExtractSlice(*this)) return result; - if (llvm::any_of(getMixedSizes(), - [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); })) + if (llvm::any_of(getMixedSizes(), isZeroInteger)) return getDest(); return OpFoldResult(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp index 7778a02dbeaf4..92540bd56ecbc 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp @@ -135,9 +135,9 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, SmallVector newStrides(rank, b.getIndexAttr(1)); for (unsigned dim = 0; dim < rank; ++dim) { auto low = padOp.getMixedLowPad()[dim]; - bool hasLowPad = !isConstantIntValue(low, 0); + bool hasLowPad = !isZeroInteger(low); auto high = padOp.getMixedHighPad()[dim]; - bool hasHighPad = !isConstantIntValue(high, 0); + bool hasHighPad = !isZeroInteger(high); auto offset = offsets[dim]; auto length = sizes[dim]; // If the dim has no padding, we dont need to calculate new values for that @@ -208,7 +208,7 @@ FailureOr tensor::bubbleUpPadSlice(OpBuilder &b, // Check if newLength is zero. In that case, no SubTensorOp should be // executed. - if (isConstantIntValue(newLength, 0)) { + if (isZeroInteger(newLength)) { hasZeroLen = true; } else if (!hasZeroLen) { Value check = b.create( diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 6525e58d002a2..b6843e560a899 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -646,7 +646,7 @@ static bool insertSliceOpRequiresRead(InsertOpTy insertSliceOp, // Dest is not read if it is entirely overwritten. E.g.: // tensor.insert_slice %a into %t[0][10][1] : ... into tensor<10xf32> bool allOffsetsZero = - llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroIndex); + llvm::all_of(insertSliceOp.getMixedOffsets(), isZeroInteger); RankedTensorType destType = insertSliceOp.getDestType(); bool sizesMatchDestSizes = areConstantIntValues(insertSliceOp.getMixedSizes(), destType.getShape()); diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index a3de7f9b44ae6..2b229d60c691b 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -452,7 +452,7 @@ struct BubbleUpExpandShapeThroughExtractSlice std::function isZeroOffsetAndFullSize = [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) { - if (!isConstantIntValue(offset, 0)) + if (!isZeroInteger(offset)) return false; FailureOr maybeEqual = ValueBoundsConstraintSet::areEqual(sliceSize, size); @@ -476,7 +476,7 @@ struct BubbleUpExpandShapeThroughExtractSlice // Find the first expanded dim after the first dim with non-unit extracted // size. for (; i < e; ++i) { - if (!isConstantIntValue(sizes[indices[i]], 1)) { + if (!isOneInteger(sizes[indices[i]])) { // +1 to skip the first non-unit size dim. i++; break; diff --git a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp index 858adfc436164..6f33f9b55ceb6 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp @@ -27,9 +27,7 @@ FailureOr tensor::replaceExtractSliceWithTiledProducer( return failure(); // `TilingInterface` currently only supports strides being 1. - if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { - return !isConstantIntValue(ofr, 1); - })) + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) return failure(); FailureOr tiledResult = producerOp.generateResultTileValue( @@ -49,9 +47,7 @@ FailureOr tensor::replaceInsertSliceWithTiledConsumer( return failure(); // `TilingInterface` currently only supports strides being 1. - if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) { - return !isConstantIntValue(ofr, 1); - })) + if (!llvm::all_of(sliceOp.getMixedStrides(), isOneInteger)) return failure(); FailureOr tiledResult = diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index fac836ebd7a36..29f7bd6857c27 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -15,14 +15,9 @@ namespace mlir { -bool isZeroIndex(OpFoldResult v) { - if (!v) - return false; - std::optional constint = getConstantIntValue(v); - if (!constint) - return false; - return *constint == 0; -} +bool isZeroInteger(OpFoldResult v) { return isConstantIntValue(v, 0); } + +bool isOneInteger(OpFoldResult v) { return isConstantIntValue(v, 1); } std::tuple, SmallVector, SmallVector> diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 060ce7d1d6643..678a88627ca82 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -116,8 +116,7 @@ static bool stridesAllOne(TOp op) { std::is_same_v, "expected vector.extract_strided_slice or vector.insert_strided_slice"); ArrayAttr strides = op.getStrides(); - return llvm::all_of( - strides, [](auto stride) { return isConstantIntValue(stride, 1); }); + return llvm::all_of(strides, isOneInteger); } /// Convert an array of attributes into a vector of integers, if possible. diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index d4d07c7eadc77..7dbb7a334fe62 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -538,7 +538,7 @@ static SmallVector getCollapsedIndices(RewriterBase &rewriter, indices.begin(), indices.begin() + firstDimToCollapse); SmallVector indicesToCollapse(indices.begin() + firstDimToCollapse, indices.end()); - if (llvm::all_of(indicesToCollapse, isZeroIndex)) { + if (llvm::all_of(indicesToCollapse, isZeroInteger)) { indicesAfterCollapsing.push_back(indicesToCollapse[0]); return indicesAfterCollapsing; } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index c635be6e83b6a..71c557f7eda06 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1118,7 +1118,7 @@ class ExtractOpFromLoad final : public OpRewritePattern { ArithIndexingBuilder idxBuilderf(rewriter, loc); for (auto i : llvm::seq(rankOffset, indices.size() - finalRank)) { OpFoldResult pos = extractPos[i - rankOffset]; - if (isConstantIntValue(pos, 0)) + if (isZeroInteger(pos)) continue; Value offset = getValueOrCreateConstantIndexOp(rewriter, loc, pos);