66//
77// ===----------------------------------------------------------------------===//
88
9+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
910#include " mlir/Dialect/Arith/IR/Arith.h"
1011#include " mlir/Dialect/SCF/IR/SCF.h"
12+ #include " mlir/Dialect/SCF/Utils/Utils.h"
1113#include " mlir/IR/Diagnostics.h"
1214#include " mlir/IR/MLIRContext.h"
1315#include " mlir/IR/OwningOpRef.h"
16+ #include " mlir/IR/PatternMatch.h"
17+ #include " mlir/Interfaces/LoopLikeInterface.h"
1418#include " gtest/gtest.h"
1519
1620using namespace mlir ;
@@ -23,7 +27,7 @@ using namespace mlir::scf;
2327class SCFLoopLikeTest : public ::testing::Test {
2428protected:
2529 SCFLoopLikeTest () : b(&context), loc(UnknownLoc::get(&context)) {
26- context.loadDialect <arith::ArithDialect, scf::SCFDialect>();
30+ context.loadDialect <affine::AffineDialect, arith::ArithDialect, scf::SCFDialect>();
2731 }
2832
2933 void checkUnidimensional (LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +92,24 @@ class SCFLoopLikeTest : public ::testing::Test {
8892 EXPECT_EQ ((*maybeInductionVars).size (), 2u );
8993 }
9094
95+ void checkNormalized (LoopLikeOpInterface loopLikeOp) {
96+ std::optional<SmallVector<OpFoldResult>> maybeLb =
97+ loopLikeOp.getLoopLowerBounds ();
98+ ASSERT_TRUE (maybeLb.has_value ());
99+ std::optional<SmallVector<OpFoldResult>> maybeStep =
100+ loopLikeOp.getLoopSteps ();
101+ ASSERT_TRUE (maybeStep.has_value ());
102+
103+ auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
104+ return llvm::all_of (results, [&](OpFoldResult ofr) {
105+ auto intValue = getConstantIntValue (ofr);
106+ return intValue.has_value () && intValue == val;
107+ });
108+ };
109+ EXPECT_TRUE (allEqual (*maybeLb, 0 ));
110+ EXPECT_TRUE (allEqual (*maybeStep, 1 ));
111+ }
112+
91113 MLIRContext context;
92114 OpBuilder b;
93115 Location loc;
@@ -138,3 +160,23 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
138160 ValueRange ({step->getResult (), step->getResult ()}), ValueRange ());
139161 checkMultidimensional (parallelOp.get ());
140162}
163+
164+ TEST_F (SCFLoopLikeTest, testForallNormalize) {
165+ OwningOpRef<arith::ConstantIndexOp> lb =
166+ b.create <arith::ConstantIndexOp>(loc, 1 );
167+ OwningOpRef<arith::ConstantIndexOp> ub =
168+ b.create <arith::ConstantIndexOp>(loc, 10 );
169+ OwningOpRef<arith::ConstantIndexOp> step =
170+ b.create <arith::ConstantIndexOp>(loc, 3 );
171+
172+ scf::ForallOp forallOp = b.create <scf::ForallOp>(
173+ loc, ArrayRef<OpFoldResult>({lb->getResult (), lb->getResult ()}),
174+ ArrayRef<OpFoldResult>({ub->getResult (), ub->getResult ()}),
175+ ArrayRef<OpFoldResult>({step->getResult (), step->getResult ()}),
176+ ValueRange (), std::nullopt );
177+ IRRewriter rewriter (b);
178+ FailureOr<scf::ForallOp> maybeNormalizedForallOp = normalizeForallOp (rewriter, forallOp);
179+ EXPECT_TRUE (succeeded (maybeNormalizedForallOp));
180+ OwningOpRef<scf::ForallOp> normalizedForallOp (*maybeNormalizedForallOp);
181+ checkNormalized (normalizedForallOp.get ());
182+ }
0 commit comments