Skip to content

Commit 9ca4664

Browse files
authored
[MLIR][SCF] Fix normalizeForallOp helper function (#138615)
Previously the `normalizeForallOp` function did not work properly, since the newly created op was not being returned in addition to the op failing verification. This patch fixes the helper function and adds a unit test for it.
1 parent 31fd77a commit 9ca4664

File tree

3 files changed

+83
-14
lines changed

3 files changed

+83
-14
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,30 +1482,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
14821482
SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
14831483
SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
14841484

1485-
if (llvm::all_of(
1486-
lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
1487-
llvm::all_of(
1488-
steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
1485+
if (forallOp.isNormalized())
14891486
return forallOp;
1490-
}
14911487

1492-
SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
1488+
OpBuilder::InsertionGuard g(rewriter);
1489+
auto loc = forallOp.getLoc();
1490+
rewriter.setInsertionPoint(forallOp);
1491+
SmallVector<OpFoldResult> newUbs;
14931492
for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
14941493
Range normalizedLoopParams =
1495-
emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
1496-
newLbs.push_back(normalizedLoopParams.offset);
1494+
emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
14971495
newUbs.push_back(normalizedLoopParams.size);
1498-
newSteps.push_back(normalizedLoopParams.stride);
14991496
}
1497+
(void)foldDynamicIndexList(newUbs);
15001498

1499+
// Use the normalized builder since the lower bounds are always 0 and the
1500+
// steps are always 1.
15011501
auto normalizedForallOp = rewriter.create<scf::ForallOp>(
1502-
forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
1503-
forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});
1502+
loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1503+
[](OpBuilder &, Location, ValueRange) {});
15041504

15051505
rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
15061506
normalizedForallOp.getBodyRegion(),
15071507
normalizedForallOp.getBodyRegion().begin());
1508+
// Remove the original empty block in the new loop.
1509+
rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
1510+
1511+
rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1512+
// Update the users of the original loop variables.
1513+
for (auto [idx, iv] :
1514+
llvm::enumerate(normalizedForallOp.getInductionVars())) {
1515+
auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
1516+
auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
1517+
denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
1518+
}
15081519

1509-
rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
1510-
return success();
1520+
rewriter.replaceOp(forallOp, normalizedForallOp);
1521+
return normalizedForallOp;
15111522
}

mlir/unittests/Dialect/SCF/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ mlir_target_link_libraries(MLIRSCFTests
55
PRIVATE
66
MLIRIR
77
MLIRSCFDialect
8+
MLIRSCFUtils
89
)

mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
910
#include "mlir/Dialect/Arith/IR/Arith.h"
1011
#include "mlir/Dialect/SCF/IR/SCF.h"
12+
#include "mlir/Dialect/SCF/Utils/Utils.h"
13+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1114
#include "mlir/IR/Diagnostics.h"
1215
#include "mlir/IR/MLIRContext.h"
1316
#include "mlir/IR/OwningOpRef.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Interfaces/LoopLikeInterface.h"
1419
#include "gtest/gtest.h"
1520

1621
using namespace mlir;
@@ -23,7 +28,8 @@ using namespace mlir::scf;
2328
class SCFLoopLikeTest : public ::testing::Test {
2429
protected:
2530
SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
26-
context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
31+
context.loadDialect<affine::AffineDialect, arith::ArithDialect,
32+
scf::SCFDialect>();
2733
}
2834

2935
void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +94,24 @@ class SCFLoopLikeTest : public ::testing::Test {
8894
EXPECT_EQ((*maybeInductionVars).size(), 2u);
8995
}
9096

97+
void checkNormalized(LoopLikeOpInterface loopLikeOp) {
98+
std::optional<SmallVector<OpFoldResult>> maybeLb =
99+
loopLikeOp.getLoopLowerBounds();
100+
ASSERT_TRUE(maybeLb.has_value());
101+
std::optional<SmallVector<OpFoldResult>> maybeStep =
102+
loopLikeOp.getLoopSteps();
103+
ASSERT_TRUE(maybeStep.has_value());
104+
105+
auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
106+
return llvm::all_of(results, [&](OpFoldResult ofr) {
107+
auto intValue = getConstantIntValue(ofr);
108+
return intValue.has_value() && intValue == val;
109+
});
110+
};
111+
EXPECT_TRUE(allEqual(*maybeLb, 0));
112+
EXPECT_TRUE(allEqual(*maybeStep, 1));
113+
}
114+
91115
MLIRContext context;
92116
OpBuilder b;
93117
Location loc;
@@ -138,3 +162,36 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
138162
ValueRange({step->getResult(), step->getResult()}), ValueRange());
139163
checkMultidimensional(parallelOp.get());
140164
}
165+
166+
TEST_F(SCFLoopLikeTest, testForallNormalize) {
167+
OwningOpRef<arith::ConstantIndexOp> lb =
168+
b.create<arith::ConstantIndexOp>(loc, 1);
169+
OwningOpRef<arith::ConstantIndexOp> ub =
170+
b.create<arith::ConstantIndexOp>(loc, 10);
171+
OwningOpRef<arith::ConstantIndexOp> step =
172+
b.create<arith::ConstantIndexOp>(loc, 3);
173+
174+
scf::ForallOp forallOp = b.create<scf::ForallOp>(
175+
loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
176+
ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
177+
ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
178+
ValueRange(), std::nullopt);
179+
// Create a user of the induction variable. Bitcast is chosen for simplicity
180+
// since it is unary.
181+
b.setInsertionPointToStart(forallOp.getBody());
182+
b.create<arith::BitcastOp>(UnknownLoc::get(&context), b.getF64Type(),
183+
forallOp.getInductionVar(0));
184+
IRRewriter rewriter(b);
185+
FailureOr<scf::ForallOp> maybeNormalizedForallOp =
186+
normalizeForallOp(rewriter, forallOp);
187+
EXPECT_TRUE(succeeded(maybeNormalizedForallOp));
188+
OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
189+
checkNormalized(normalizedForallOp.get());
190+
191+
// Check that the IV user has been updated to use the denormalized variable.
192+
Block *body = normalizedForallOp->getBody();
193+
auto bitcastOps = body->getOps<arith::BitcastOp>();
194+
ASSERT_EQ(std::distance(bitcastOps.begin(), bitcastOps.end()), 1);
195+
arith::BitcastOp ivUser = *bitcastOps.begin();
196+
ASSERT_NE(ivUser.getIn(), normalizedForallOp->getInductionVar(0));
197+
}

0 commit comments

Comments
 (0)