| 
23 | 23 | #include "mlir/Dialect/SCF/Transforms/Transforms.h"  | 
24 | 24 | #include "mlir/Dialect/Tensor/IR/Tensor.h"  | 
25 | 25 | #include "mlir/Dialect/Utils/IndexingUtils.h"  | 
 | 26 | +#include "mlir/Dialect/Utils/StaticValueUtils.h"  | 
26 | 27 | #include "mlir/IR/AffineExpr.h"  | 
27 | 28 | #include "mlir/IR/AffineMap.h"  | 
28 | 29 | #include "mlir/IR/BuiltinOps.h"  | 
@@ -376,13 +377,13 @@ static void calculateTileOffsetsAndSizes(  | 
376 | 377 | 
 
  | 
377 | 378 |   SmallVector<Value> threadIds = forallOp.getInductionVars();  | 
378 | 379 |   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(  | 
379 |  | -      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });  | 
 | 380 | +      numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });  | 
380 | 381 |   int64_t nLoops = loopRanges.size();  | 
381 | 382 |   tiledOffsets.reserve(nLoops);  | 
382 | 383 |   tiledSizes.reserve(nLoops);  | 
383 | 384 |   for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) {  | 
384 | 385 |     bool overflow = loopIdx >= numThreads.size();  | 
385 |  | -    bool isZero = !overflow && isConstantIntValue(numThreads[loopIdx], 0);  | 
 | 386 | +    bool isZero = !overflow && isZeroInteger(numThreads[loopIdx]);  | 
386 | 387 |     // Degenerate case: take the whole domain.  | 
387 | 388 |     if (overflow || isZero) {  | 
388 | 389 |       tiledOffsets.push_back(loopRanges[loopIdx].offset);  | 
@@ -413,7 +414,7 @@ static void calculateTileOffsetsAndSizes(  | 
413 | 414 |     OpFoldResult residualTileSize = makeComposedFoldedAffineApply(  | 
414 | 415 |         b, loc, i + j * m - n,  | 
415 | 416 |         {offset, nonZeroNumThreads[threadIdIdx], tileSizePerThread, size});  | 
416 |  | -    if (!isConstantIntValue(residualTileSize, 0)) {  | 
 | 417 | +    if (!isZeroInteger(residualTileSize)) {  | 
417 | 418 |       OpFoldResult sizeMinusOffsetPerThread = makeComposedFoldedAffineApply(  | 
418 | 419 |           b, loc, -i + m, {offsetPerThread, size});  | 
419 | 420 |       tileSizePerThread =  | 
@@ -655,7 +656,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(  | 
655 | 656 |   Operation *tiledOp = nullptr;  | 
656 | 657 | 
 
  | 
657 | 658 |   SmallVector<OpFoldResult> nonZeroNumThreads = llvm::filter_to_vector(  | 
658 |  | -      numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); });  | 
 | 659 | +      numThreads, [](OpFoldResult ofr) { return !isZeroInteger(ofr); });  | 
659 | 660 |   SmallVector<Value> materializedNonZeroNumThreads =  | 
660 | 661 |       getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads);  | 
661 | 662 | 
 
  | 
 | 
0 commit comments