@@ -133,7 +133,7 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
133133 tileSizes.resize (numLoops, zero);
134134 for (auto [index, range, nt] :
135135 llvm::enumerate (iterationDomain, numThreads)) {
136- if (isConstantIntValue (nt, 0 ))
136+ if (isZeroIndex (nt))
137137 continue ;
138138
139139 tileSizes[index] = affine::makeComposedFoldedAffineApply (
@@ -265,7 +265,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
265265
266266 // Non-tiled cases, set the offset and size to the
267267 // `loopRange.offset/size`.
268- if (isConstantIntValue (nt, 0 )) {
268+ if (isZeroIndex (nt)) {
269269 offsets.push_back (loopRange.offset );
270270 sizes.push_back (loopRange.size );
271271 continue ;
@@ -280,7 +280,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
280280 {loopRange.offset , nt, tileSize, loopRange.size });
281281
282282 OpFoldResult size = tileSize;
283- if (!isConstantIntValue (residualTileSize, 0 )) {
283+ if (!isZeroIndex (residualTileSize)) {
284284 OpFoldResult sizeMinusOffsetPerThread =
285285 affine::makeComposedFoldedAffineApply (rewriter, loc, s0 - d0,
286286 {offset, loopRange.size });
@@ -316,7 +316,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
316316
317317 // Non-tiled cases, set the offset and size to the
318318 // `loopRange.offset/size`.
319- if (isConstantIntValue (tileSize, 0 )) {
319+ if (isZeroIndex (tileSize)) {
320320 offsets.push_back (loopRange.offset );
321321 sizes.push_back (loopRange.size );
322322 continue ;
@@ -341,7 +341,7 @@ getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
341341 SmallVector<OpFoldResult> lbs, ubs, steps;
342342 for (auto [loopRange, tileSize] : llvm::zip_equal (loopRanges, tileSizes)) {
343343 // No loop if the tile size is 0.
344- if (isConstantIntValue (tileSize, 0 ))
344+ if (isZeroIndex (tileSize))
345345 continue ;
346346 lbs.push_back (loopRange.offset );
347347 ubs.push_back (loopRange.size );
@@ -495,7 +495,7 @@ static LogicalResult generateLoopNestUsingForallOp(
495495 // Prune the zero numthreads.
496496 SmallVector<OpFoldResult> nonZeroNumThreads;
497497 for (auto nt : numThreads) {
498- if (isConstantIntValue (nt, 0 ))
498+ if (isZeroIndex (nt))
499499 continue ;
500500 nonZeroNumThreads.push_back (nt);
501501 }
@@ -1290,9 +1290,7 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer(
12901290 sliceSizes = sliceOp.getMixedSizes ();
12911291
12921292 // expect all strides of sliceOp being 1
1293- if (llvm::any_of (sliceOp.getMixedStrides (), [](OpFoldResult ofr) {
1294- return !isConstantIntValue (ofr, 1 );
1295- }))
1293+ if (!llvm::all_of (sliceOp.getMixedStrides (), isOneIndex))
12961294 return failure ();
12971295
12981296 unsigned sliceResultNumber =
@@ -2114,9 +2112,7 @@ mlir::scf::tileAndFuseConsumerOfSlice(
21142112 SmallVector<OpFoldResult> strides = ossSliceOp.getMixedStrides ();
21152113
21162114 // 9. Check all insert stride is 1.
2117- if (llvm::any_of (strides, [](OpFoldResult stride) {
2118- return !isConstantIntValue (stride, 1 );
2119- })) {
2115+ if (!llvm::all_of (strides, isOneIndex)) {
21202116 return rewriter.notifyMatchFailure (
21212117 candidateSliceOp, " containingOp's result yield with stride" );
21222118 }
0 commit comments