Skip to content

Commit e1acacc

Browse files
committed
[MLIR][SCF] Fix normalizeForallOp helper function
Previously the `normalizeForallOp` function did not work properly, since the newly created op was not being returned in addition to the op failing verification. This patch fixes the helper function and adds a unit test for it.
1 parent e0a951f commit e1acacc

File tree

3 files changed

+57
-13
lines changed

3 files changed

+57
-13
lines changed

mlir/lib/Dialect/SCF/Utils/Utils.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,30 +1482,31 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
14821482
SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
14831483
SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
14841484

1485-
if (llvm::all_of(
1486-
lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
1487-
llvm::all_of(
1488-
steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
1485+
if (forallOp.isNormalized())
14891486
return forallOp;
1490-
}
14911487

1492-
SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
1488+
OpBuilder::InsertionGuard g(rewriter);
1489+
rewriter.setInsertionPoint(forallOp);
1490+
SmallVector<OpFoldResult> newUbs;
14931491
for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
14941492
Range normalizedLoopParams =
14951493
emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
1496-
newLbs.push_back(normalizedLoopParams.offset);
14971494
newUbs.push_back(normalizedLoopParams.size);
1498-
newSteps.push_back(normalizedLoopParams.stride);
14991495
}
1496+
(void)foldDynamicIndexList(newUbs);
15001497

1498+
// Use the normalized builder since the lower bounds are always 0 and the
1499+
// steps are always 1.
15011500
auto normalizedForallOp = rewriter.create<scf::ForallOp>(
1502-
forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
1503-
forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});
1501+
forallOp.getLoc(), newUbs, forallOp.getOutputs(), forallOp.getMapping(),
1502+
[](OpBuilder &, Location, ValueRange) {});
15041503

15051504
rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
15061505
normalizedForallOp.getBodyRegion(),
15071506
normalizedForallOp.getBodyRegion().begin());
1507+
// Remove the original empty block in the new loop.
1508+
rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
15081509

1509-
rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
1510-
return success();
1510+
rewriter.replaceOp(forallOp, normalizedForallOp);
1511+
return normalizedForallOp;
15111512
}

mlir/unittests/Dialect/SCF/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ mlir_target_link_libraries(MLIRSCFTests
55
PRIVATE
66
MLIRIR
77
MLIRSCFDialect
8+
MLIRSCFUtils
89
)

mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,15 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
910
#include "mlir/Dialect/Arith/IR/Arith.h"
1011
#include "mlir/Dialect/SCF/IR/SCF.h"
12+
#include "mlir/Dialect/SCF/Utils/Utils.h"
1113
#include "mlir/IR/Diagnostics.h"
1214
#include "mlir/IR/MLIRContext.h"
1315
#include "mlir/IR/OwningOpRef.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Interfaces/LoopLikeInterface.h"
1418
#include "gtest/gtest.h"
1519

1620
using namespace mlir;
@@ -23,7 +27,7 @@ using namespace mlir::scf;
2327
class SCFLoopLikeTest : public ::testing::Test {
2428
protected:
2529
SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
26-
context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
30+
context.loadDialect<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect>();
2731
}
2832

2933
void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +92,24 @@ class SCFLoopLikeTest : public ::testing::Test {
8892
EXPECT_EQ((*maybeInductionVars).size(), 2u);
8993
}
9094

95+
void checkNormalized(LoopLikeOpInterface loopLikeOp) {
96+
std::optional<SmallVector<OpFoldResult>> maybeLb =
97+
loopLikeOp.getLoopLowerBounds();
98+
ASSERT_TRUE(maybeLb.has_value());
99+
std::optional<SmallVector<OpFoldResult>> maybeStep =
100+
loopLikeOp.getLoopSteps();
101+
ASSERT_TRUE(maybeStep.has_value());
102+
103+
auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
104+
return llvm::all_of(results, [&](OpFoldResult ofr) {
105+
auto intValue = getConstantIntValue(ofr);
106+
return intValue.has_value() && intValue == val;
107+
});
108+
};
109+
EXPECT_TRUE(allEqual(*maybeLb, 0));
110+
EXPECT_TRUE(allEqual(*maybeStep, 1));
111+
}
112+
91113
MLIRContext context;
92114
OpBuilder b;
93115
Location loc;
@@ -138,3 +160,23 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
138160
ValueRange({step->getResult(), step->getResult()}), ValueRange());
139161
checkMultidimensional(parallelOp.get());
140162
}
163+
164+
TEST_F(SCFLoopLikeTest, testForallNormalize) {
165+
OwningOpRef<arith::ConstantIndexOp> lb =
166+
b.create<arith::ConstantIndexOp>(loc, 1);
167+
OwningOpRef<arith::ConstantIndexOp> ub =
168+
b.create<arith::ConstantIndexOp>(loc, 10);
169+
OwningOpRef<arith::ConstantIndexOp> step =
170+
b.create<arith::ConstantIndexOp>(loc, 3);
171+
172+
scf::ForallOp forallOp = b.create<scf::ForallOp>(
173+
loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
174+
ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
175+
ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
176+
ValueRange(), std::nullopt);
177+
IRRewriter rewriter(b);
178+
FailureOr<scf::ForallOp> maybeNormalizedForallOp = normalizeForallOp(rewriter, forallOp);
179+
EXPECT_TRUE(succeeded(maybeNormalizedForallOp));
180+
OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
181+
checkNormalized(normalizedForallOp.get());
182+
}

0 commit comments

Comments
 (0)