From efa17605abf6fe0022865a581fb17913a29fe08a Mon Sep 17 00:00:00 2001 From: Colin De Vlieghere Date: Mon, 5 May 2025 19:10:24 -0700 Subject: [PATCH] [MLIR][SCF] Fix normalizeForallOp helper function Previously the `normalizeForallOp` function did not work properly: the newly created op was not returned, and it also failed verification. This patch fixes the helper function so that it returns the new op, updates the induction variable users inside the loop to use the denormalized variables, and adds a unit test for it. --- mlir/lib/Dialect/SCF/Utils/Utils.cpp | 37 ++++++++---- mlir/unittests/Dialect/SCF/CMakeLists.txt | 1 + .../Dialect/SCF/LoopLikeSCFOpsTest.cpp | 59 ++++++++++++++++++- 3 files changed, 83 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index e9471c1dbd0b7..d9550fe18dc02 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1482,30 +1482,41 @@ FailureOr mlir::normalizeForallOp(RewriterBase &rewriter, SmallVector ubs = forallOp.getMixedUpperBound(); SmallVector steps = forallOp.getMixedStep(); - if (llvm::all_of( - lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) && - llvm::all_of( - steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) { + if (forallOp.isNormalized()) return forallOp; - } - SmallVector newLbs, newUbs, newSteps; + OpBuilder::InsertionGuard g(rewriter); + auto loc = forallOp.getLoc(); + rewriter.setInsertionPoint(forallOp); + SmallVector newUbs; for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) { Range normalizedLoopParams = - emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step); - newLbs.push_back(normalizedLoopParams.offset); + emitNormalizedLoopBounds(rewriter, loc, lb, ub, step); newUbs.push_back(normalizedLoopParams.size); - newSteps.push_back(normalizedLoopParams.stride); } + (void)foldDynamicIndexList(newUbs); + // Use the normalized builder since the lower 
bounds are always 0 and the + // steps are always 1. auto normalizedForallOp = rewriter.create( - forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(), - forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {}); + loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(), + [](OpBuilder &, Location, ValueRange) {}); rewriter.inlineRegionBefore(forallOp.getBodyRegion(), normalizedForallOp.getBodyRegion(), normalizedForallOp.getBodyRegion().begin()); + // Remove the original empty block in the new loop. + rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back()); + + rewriter.setInsertionPointToStart(normalizedForallOp.getBody()); + // Update the users of the original loop variables. + for (auto [idx, iv] : + llvm::enumerate(normalizedForallOp.getInductionVars())) { + auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]); + auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]); + denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep); + } - rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp); - return success(); + rewriter.replaceOp(forallOp, normalizedForallOp); + return normalizedForallOp; } diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt index c0c1757b80fb5..83cefbcabf4d9 100644 --- a/mlir/unittests/Dialect/SCF/CMakeLists.txt +++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt @@ -5,4 +5,5 @@ mlir_target_link_libraries(MLIRSCFTests PRIVATE MLIRIR MLIRSCFDialect + MLIRSCFUtils ) diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp index 53a4af14d119a..fecd960d694b1 100644 --- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp +++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp @@ -6,11 +6,16 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" 
#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OwningOpRef.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/LoopLikeInterface.h" #include "gtest/gtest.h" using namespace mlir; @@ -23,7 +28,8 @@ using namespace mlir::scf; class SCFLoopLikeTest : public ::testing::Test { protected: SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) { - context.loadDialect(); + context.loadDialect(); } void checkUnidimensional(LoopLikeOpInterface loopLikeOp) { @@ -88,6 +94,24 @@ class SCFLoopLikeTest : public ::testing::Test { EXPECT_EQ((*maybeInductionVars).size(), 2u); } + void checkNormalized(LoopLikeOpInterface loopLikeOp) { + std::optional> maybeLb = + loopLikeOp.getLoopLowerBounds(); + ASSERT_TRUE(maybeLb.has_value()); + std::optional> maybeStep = + loopLikeOp.getLoopSteps(); + ASSERT_TRUE(maybeStep.has_value()); + + auto allEqual = [](ArrayRef results, int64_t val) { + return llvm::all_of(results, [&](OpFoldResult ofr) { + auto intValue = getConstantIntValue(ofr); + return intValue.has_value() && intValue == val; + }); + }; + EXPECT_TRUE(allEqual(*maybeLb, 0)); + EXPECT_TRUE(allEqual(*maybeStep, 1)); + } + MLIRContext context; OpBuilder b; Location loc; @@ -138,3 +162,36 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) { ValueRange({step->getResult(), step->getResult()}), ValueRange()); checkMultidimensional(parallelOp.get()); } + +TEST_F(SCFLoopLikeTest, testForallNormalize) { + OwningOpRef lb = + b.create(loc, 1); + OwningOpRef ub = + b.create(loc, 10); + OwningOpRef step = + b.create(loc, 3); + + scf::ForallOp forallOp = b.create( + loc, ArrayRef({lb->getResult(), lb->getResult()}), + ArrayRef({ub->getResult(), ub->getResult()}), + ArrayRef({step->getResult(), step->getResult()}), + ValueRange(), std::nullopt); + // Create a user of the induction variable. 
Bitcast is chosen for simplicity + // since it is unary. + b.setInsertionPointToStart(forallOp.getBody()); + b.create(UnknownLoc::get(&context), b.getF64Type(), + forallOp.getInductionVar(0)); + IRRewriter rewriter(b); + FailureOr maybeNormalizedForallOp = + normalizeForallOp(rewriter, forallOp); + EXPECT_TRUE(succeeded(maybeNormalizedForallOp)); + OwningOpRef normalizedForallOp(*maybeNormalizedForallOp); + checkNormalized(normalizedForallOp.get()); + + // Check that the IV user has been updated to use the denormalized variable. + Block *body = normalizedForallOp->getBody(); + auto bitcastOps = body->getOps(); + ASSERT_EQ(std::distance(bitcastOps.begin(), bitcastOps.end()), 1); + arith::BitcastOp ivUser = *bitcastOps.begin(); + ASSERT_NE(ivUser.getIn(), normalizedForallOp->getInductionVar(0)); +}