6
6
//
7
7
// ===----------------------------------------------------------------------===//
8
8
9
+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
9
10
#include " mlir/Dialect/Arith/IR/Arith.h"
10
11
#include " mlir/Dialect/SCF/IR/SCF.h"
12
+ #include " mlir/Dialect/SCF/Utils/Utils.h"
13
+ #include " mlir/Dialect/Utils/StaticValueUtils.h"
11
14
#include " mlir/IR/Diagnostics.h"
12
15
#include " mlir/IR/MLIRContext.h"
13
16
#include " mlir/IR/OwningOpRef.h"
17
+ #include " mlir/IR/PatternMatch.h"
18
+ #include " mlir/Interfaces/LoopLikeInterface.h"
14
19
#include " gtest/gtest.h"
15
20
16
21
using namespace mlir ;
@@ -23,7 +28,8 @@ using namespace mlir::scf;
23
28
class SCFLoopLikeTest : public ::testing::Test {
24
29
protected:
25
30
SCFLoopLikeTest () : b(&context), loc(UnknownLoc::get(&context)) {
26
- context.loadDialect <arith::ArithDialect, scf::SCFDialect>();
31
+ context.loadDialect <affine::AffineDialect, arith::ArithDialect,
32
+ scf::SCFDialect>();
27
33
}
28
34
29
35
void checkUnidimensional (LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +94,24 @@ class SCFLoopLikeTest : public ::testing::Test {
88
94
EXPECT_EQ ((*maybeInductionVars).size (), 2u );
89
95
}
90
96
97
+ void checkNormalized (LoopLikeOpInterface loopLikeOp) {
98
+ std::optional<SmallVector<OpFoldResult>> maybeLb =
99
+ loopLikeOp.getLoopLowerBounds ();
100
+ ASSERT_TRUE (maybeLb.has_value ());
101
+ std::optional<SmallVector<OpFoldResult>> maybeStep =
102
+ loopLikeOp.getLoopSteps ();
103
+ ASSERT_TRUE (maybeStep.has_value ());
104
+
105
+ auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
106
+ return llvm::all_of (results, [&](OpFoldResult ofr) {
107
+ auto intValue = getConstantIntValue (ofr);
108
+ return intValue.has_value () && intValue == val;
109
+ });
110
+ };
111
+ EXPECT_TRUE (allEqual (*maybeLb, 0 ));
112
+ EXPECT_TRUE (allEqual (*maybeStep, 1 ));
113
+ }
114
+
91
115
MLIRContext context;
92
116
OpBuilder b;
93
117
Location loc;
@@ -138,3 +162,36 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
138
162
ValueRange ({step->getResult (), step->getResult ()}), ValueRange ());
139
163
checkMultidimensional (parallelOp.get ());
140
164
}
165
+
166
+ TEST_F (SCFLoopLikeTest, testForallNormalize) {
167
+ OwningOpRef<arith::ConstantIndexOp> lb =
168
+ b.create <arith::ConstantIndexOp>(loc, 1 );
169
+ OwningOpRef<arith::ConstantIndexOp> ub =
170
+ b.create <arith::ConstantIndexOp>(loc, 10 );
171
+ OwningOpRef<arith::ConstantIndexOp> step =
172
+ b.create <arith::ConstantIndexOp>(loc, 3 );
173
+
174
+ scf::ForallOp forallOp = b.create <scf::ForallOp>(
175
+ loc, ArrayRef<OpFoldResult>({lb->getResult (), lb->getResult ()}),
176
+ ArrayRef<OpFoldResult>({ub->getResult (), ub->getResult ()}),
177
+ ArrayRef<OpFoldResult>({step->getResult (), step->getResult ()}),
178
+ ValueRange (), std::nullopt);
179
+ // Create a user of the induction variable. Bitcast is chosen for simplicity
180
+ // since it is unary.
181
+ b.setInsertionPointToStart (forallOp.getBody ());
182
+ b.create <arith::BitcastOp>(UnknownLoc::get (&context), b.getF64Type (),
183
+ forallOp.getInductionVar (0 ));
184
+ IRRewriter rewriter (b);
185
+ FailureOr<scf::ForallOp> maybeNormalizedForallOp =
186
+ normalizeForallOp (rewriter, forallOp);
187
+ EXPECT_TRUE (succeeded (maybeNormalizedForallOp));
188
+ OwningOpRef<scf::ForallOp> normalizedForallOp (*maybeNormalizedForallOp);
189
+ checkNormalized (normalizedForallOp.get ());
190
+
191
+ // Check that the IV user has been updated to use the denormalized variable.
192
+ Block *body = normalizedForallOp->getBody ();
193
+ auto bitcastOps = body->getOps <arith::BitcastOp>();
194
+ ASSERT_EQ (std::distance (bitcastOps.begin (), bitcastOps.end ()), 1 );
195
+ arith::BitcastOp ivUser = *bitcastOps.begin ();
196
+ ASSERT_NE (ivUser.getIn (), normalizedForallOp->getInductionVar (0 ));
197
+ }
0 commit comments