diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index e9471c1dbd0b7..d9550fe18dc02 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1482,30 +1482,41 @@ FailureOr mlir::normalizeForallOp(RewriterBase &rewriter, SmallVector ubs = forallOp.getMixedUpperBound(); SmallVector steps = forallOp.getMixedStep(); - if (llvm::all_of( - lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) && - llvm::all_of( - steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) { + if (forallOp.isNormalized()) return forallOp; - } - SmallVector newLbs, newUbs, newSteps; + OpBuilder::InsertionGuard g(rewriter); + auto loc = forallOp.getLoc(); + rewriter.setInsertionPoint(forallOp); + SmallVector newUbs; for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) { Range normalizedLoopParams = - emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step); - newLbs.push_back(normalizedLoopParams.offset); + emitNormalizedLoopBounds(rewriter, loc, lb, ub, step); newUbs.push_back(normalizedLoopParams.size); - newSteps.push_back(normalizedLoopParams.stride); } + (void)foldDynamicIndexList(newUbs); + // Use the normalized builder since the lower bounds are always 0 and the + // steps are always 1. auto normalizedForallOp = rewriter.create( - forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(), - forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {}); + loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(), + [](OpBuilder &, Location, ValueRange) {}); rewriter.inlineRegionBefore(forallOp.getBodyRegion(), normalizedForallOp.getBodyRegion(), normalizedForallOp.getBodyRegion().begin()); + // Remove the original empty block in the new loop. + rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back()); + + rewriter.setInsertionPointToStart(normalizedForallOp.getBody()); + // Update the users of the original loop variables. 
+ for (auto [idx, iv] : + llvm::enumerate(normalizedForallOp.getInductionVars())) { + auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]); + auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]); + denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep); + } - rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp); - return success(); + rewriter.replaceOp(forallOp, normalizedForallOp); + return normalizedForallOp; } diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt index c0c1757b80fb5..83cefbcabf4d9 100644 --- a/mlir/unittests/Dialect/SCF/CMakeLists.txt +++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt @@ -5,4 +5,5 @@ mlir_target_link_libraries(MLIRSCFTests PRIVATE MLIRIR MLIRSCFDialect + MLIRSCFUtils ) diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp index 53a4af14d119a..fecd960d694b1 100644 --- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp +++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp @@ -6,11 +6,16 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "gtest/gtest.h" using namespace mlir; @@ -23,7 +28,8 @@ using namespace mlir::scf; class SCFLoopLikeTest : public ::testing::Test { protected: SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) { - context.loadDialect(); + context.loadDialect(); } void checkUnidimensional(LoopLikeOpInterface loopLikeOp) { @@ -88,6 +94,24 @@ class SCFLoopLikeTest : public ::testing::Test { 
EXPECT_EQ((*maybeInductionVars).size(), 2u);
   }
 
+  void checkNormalized(LoopLikeOpInterface loopLikeOp) {
+    std::optional<SmallVector<OpFoldResult>> maybeLb =
+        loopLikeOp.getLoopLowerBounds();
+    ASSERT_TRUE(maybeLb.has_value());
+    std::optional<SmallVector<OpFoldResult>> maybeStep =
+        loopLikeOp.getLoopSteps();
+    ASSERT_TRUE(maybeStep.has_value());
+
+    auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
+      return llvm::all_of(results, [&](OpFoldResult ofr) {
+        auto intValue = getConstantIntValue(ofr);
+        return intValue.has_value() && intValue == val;
+      });
+    };
+    EXPECT_TRUE(allEqual(*maybeLb, 0));
+    EXPECT_TRUE(allEqual(*maybeStep, 1));
+  }
+
   MLIRContext context;
   OpBuilder b;
   Location loc;
@@ -138,3 +162,40 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
       ValueRange({step->getResult(), step->getResult()}), ValueRange());
   checkMultidimensional(parallelOp.get());
 }
+
+// Normalizes a two-dimensional scf::ForallOp built with lb = 1, ub = 10 and
+// step = 3, then checks the new loop's bounds and the rewired IV user.
+TEST_F(SCFLoopLikeTest, testForallNormalize) {
+  OwningOpRef<arith::ConstantIndexOp> lb =
+      b.create<arith::ConstantIndexOp>(loc, 1);
+  OwningOpRef<arith::ConstantIndexOp> ub =
+      b.create<arith::ConstantIndexOp>(loc, 10);
+  OwningOpRef<arith::ConstantIndexOp> step =
+      b.create<arith::ConstantIndexOp>(loc, 3);
+
+  scf::ForallOp forallOp = b.create<scf::ForallOp>(
+      loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
+      ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
+      ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
+      ValueRange(), std::nullopt);
+  // Create a user of the induction variable. Bitcast is chosen for simplicity
+  // since it is unary.
+  b.setInsertionPointToStart(forallOp.getBody());
+  b.create<arith::BitcastOp>(UnknownLoc::get(&context), b.getF64Type(),
+                             forallOp.getInductionVar(0));
+  IRRewriter rewriter(b);
+  FailureOr<scf::ForallOp> maybeNormalizedForallOp =
+      normalizeForallOp(rewriter, forallOp);
+  // Abort the test on failure: the result is dereferenced on the next line,
+  // which is only valid for a successful FailureOr.
+  ASSERT_TRUE(succeeded(maybeNormalizedForallOp));
+  OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
+  checkNormalized(normalizedForallOp.get());
+
+  // Check that the IV user has been updated to use the denormalized variable.
+  Block *body = normalizedForallOp->getBody();
+  auto bitcastOps = body->getOps<arith::BitcastOp>();
+  ASSERT_EQ(std::distance(bitcastOps.begin(), bitcastOps.end()), 1);
+  arith::BitcastOp ivUser = *bitcastOps.begin();
+  ASSERT_NE(ivUser.getIn(), normalizedForallOp->getInductionVar(0));
+}