Skip to content

Commit efa1760

Browse files
committed
[MLIR][SCF] Fix normalizeForallOp helper function
Previously, the `normalizeForallOp` helper function did not work properly: the newly created op was not returned to the caller, and that op also failed verification. This patch fixes the helper so that it returns the normalized op and updates the users of the induction variables inside the loop to use the denormalized values, and it adds a unit test for the helper.
1 parent 02139b1 commit efa1760

File tree

3 files changed

+83
-14
lines changed

3 files changed

+83
-14
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,30 +1482,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
14821482
SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
14831483
SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
14841484

1485-
if (llvm::all_of(
1486-
lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
1487-
llvm::all_of(
1488-
steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
1485+
if (forallOp.isNormalized())
14891486
return forallOp;
1490-
}
14911487

1492-
SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
1488+
OpBuilder::InsertionGuard g(rewriter);
1489+
auto loc = forallOp.getLoc();
1490+
rewriter.setInsertionPoint(forallOp);
1491+
SmallVector<OpFoldResult> newUbs;
14931492
for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
14941493
Range normalizedLoopParams =
1495-
emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
1496-
newLbs.push_back(normalizedLoopParams.offset);
1494+
emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
14971495
newUbs.push_back(normalizedLoopParams.size);
1498-
newSteps.push_back(normalizedLoopParams.stride);
14991496
}
1497+
(void)foldDynamicIndexList(newUbs);
15001498

1499+
// Use the normalized builder since the lower bounds are always 0 and the
1500+
// steps are always 1.
15011501
auto normalizedForallOp = rewriter.create<scf::ForallOp>(
1502-
forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
1503-
forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});
1502+
loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1503+
[](OpBuilder &, Location, ValueRange) {});
15041504

15051505
rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
15061506
normalizedForallOp.getBodyRegion(),
15071507
normalizedForallOp.getBodyRegion().begin());
1508+
// Remove the original empty block in the new loop.
1509+
rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
1510+
1511+
rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
1512+
// Update the users of the original loop variables.
1513+
for (auto [idx, iv] :
1514+
llvm::enumerate(normalizedForallOp.getInductionVars())) {
1515+
auto origLb = getValueOrCreateConstantIndexOp(rewriter, loc, lbs[idx]);
1516+
auto origStep = getValueOrCreateConstantIndexOp(rewriter, loc, steps[idx]);
1517+
denormalizeInductionVariable(rewriter, loc, iv, origLb, origStep);
1518+
}
15081519

1509-
rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
1510-
return success();
1520+
rewriter.replaceOp(forallOp, normalizedForallOp);
1521+
return normalizedForallOp;
15111522
}

mlir/unittests/Dialect/SCF/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ mlir_target_link_libraries(MLIRSCFTests
55
PRIVATE
66
MLIRIR
77
MLIRSCFDialect
8+
MLIRSCFUtils
89
)

mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
910
#include "mlir/Dialect/Arith/IR/Arith.h"
1011
#include "mlir/Dialect/SCF/IR/SCF.h"
12+
#include "mlir/Dialect/SCF/Utils/Utils.h"
13+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1114
#include "mlir/IR/Diagnostics.h"
1215
#include "mlir/IR/MLIRContext.h"
1316
#include "mlir/IR/OwningOpRef.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Interfaces/LoopLikeInterface.h"
1419
#include "gtest/gtest.h"
1520

1621
using namespace mlir;
@@ -23,7 +28,8 @@ using namespace mlir::scf;
2328
class SCFLoopLikeTest : public ::testing::Test {
2429
protected:
2530
SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
26-
context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
31+
context.loadDialect<affine::AffineDialect, arith::ArithDialect,
32+
scf::SCFDialect>();
2733
}
2834

2935
void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +94,24 @@ class SCFLoopLikeTest : public ::testing::Test {
8894
EXPECT_EQ((*maybeInductionVars).size(), 2u);
8995
}
9096

97+
// Checks that every lower bound reported by `loopLikeOp` is the constant 0 and
// that every step it reports is the constant 1. Fails fatally if the op does
// not report lower bounds or steps at all.
void checkNormalized(LoopLikeOpInterface loopLikeOp) {
  std::optional<SmallVector<OpFoldResult>> lowerBounds =
      loopLikeOp.getLoopLowerBounds();
  ASSERT_TRUE(lowerBounds.has_value());
  std::optional<SmallVector<OpFoldResult>> loopSteps =
      loopLikeOp.getLoopSteps();
  ASSERT_TRUE(loopSteps.has_value());

  // True iff each entry folds to a constant integer equal to `expected`.
  auto isAllConstant = [](ArrayRef<OpFoldResult> ofrs, int64_t expected) {
    for (OpFoldResult ofr : ofrs) {
      std::optional<int64_t> constant = getConstantIntValue(ofr);
      if (!constant || *constant != expected)
        return false;
    }
    return true;
  };
  EXPECT_TRUE(isAllConstant(*lowerBounds, 0));
  EXPECT_TRUE(isAllConstant(*loopSteps, 1));
}
114+
91115
MLIRContext context;
92116
OpBuilder b;
93117
Location loc;
@@ -138,3 +162,36 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
138162
ValueRange({step->getResult(), step->getResult()}), ValueRange());
139163
checkMultidimensional(parallelOp.get());
140164
}
165+
166+
// Builds a two-dimensional scf.forall whose lower bound, upper bound and step
// are the constants 1, 10 and 3 in each dimension, runs normalizeForallOp on
// it, and checks that (a) the returned loop has zero lower bounds and unit
// steps, and (b) the user of the original induction variable no longer reads
// the new loop's induction variable directly.
TEST_F(SCFLoopLikeTest, testForallNormalize) {
  OwningOpRef<arith::ConstantIndexOp> lb =
      b.create<arith::ConstantIndexOp>(loc, 1);
  OwningOpRef<arith::ConstantIndexOp> ub =
      b.create<arith::ConstantIndexOp>(loc, 10);
  OwningOpRef<arith::ConstantIndexOp> step =
      b.create<arith::ConstantIndexOp>(loc, 3);

  scf::ForallOp forallOp = b.create<scf::ForallOp>(
      loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
      ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
      ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
      ValueRange(), std::nullopt);
  // Create a user of the induction variable. Bitcast is chosen for simplicity
  // since it is unary.
  b.setInsertionPointToStart(forallOp.getBody());
  b.create<arith::BitcastOp>(UnknownLoc::get(&context), b.getF64Type(),
                             forallOp.getInductionVar(0));
  IRRewriter rewriter(b);
  FailureOr<scf::ForallOp> maybeNormalizedForallOp =
      normalizeForallOp(rewriter, forallOp);
  // Must be a fatal assertion: the result is dereferenced on the next line,
  // and a non-fatal EXPECT_* would fall through into that dereference when
  // normalization fails.
  ASSERT_TRUE(succeeded(maybeNormalizedForallOp));
  OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
  checkNormalized(normalizedForallOp.get());

  // Check that the IV user has been updated to use the denormalized variable.
  Block *body = normalizedForallOp->getBody();
  auto bitcastOps = body->getOps<arith::BitcastOp>();
  ASSERT_EQ(std::distance(bitcastOps.begin(), bitcastOps.end()), 1);
  arith::BitcastOp ivUser = *bitcastOps.begin();
  ASSERT_NE(ivUser.getIn(), normalizedForallOp->getInductionVar(0));
}

0 commit comments

Comments
 (0)