Skip to content

Commit 343319a

Browse files
terapines open source contributor 2AdUhTkJm
authored and committed
[CIR][ThroughMLIR] Fix ForOp handling (llvm#1615)
Currently, the ForOp handling ignores everything in the step region except loads, stores and arithmetic. It also does not detect whether the step and the induction variable have already been assigned. That might result in wrong behaviour:
```cpp
// Ignores printf
for (int i = 0; i < n; i++, printf("\n"));
// Only increments once
for (int i = 0; i < n; i++, i++);
```
I chose to rewrite the detection and do an exact match of the instruction sequence for `i++` and `i += n`. It doesn't seem easy to detect a more general pattern without phi nodes. The new test case is xfailed, because ForOp hits an unreachable when it meets a non-canonicalized loop. We can implement that functionality later. Co-authored-by: Yue Huang <[email protected]>
1 parent 4784d3b commit 343319a

File tree

3 files changed

+120
-43
lines changed

3 files changed

+120
-43
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 106 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -42,21 +42,24 @@ class SCFLoop {
4242
int64_t getStep() { return step; }
4343
mlir::Value getLowerBound() { return lowerBound; }
4444
mlir::Value getUpperBound() { return upperBound; }
45+
bool isCanonical() { return canonical; }
4546

46-
int64_t findStepAndIV(mlir::Value &addr);
47+
// Returns success if both the step and the induction variable are found.
48+
mlir::LogicalResult findStepAndIV();
4749
cir::CmpOp findCmpOp();
4850
mlir::Value findIVInitValue();
4951
void analysis();
5052

51-
mlir::Value plusConstant(mlir::Value V, mlir::Location loc, int addend);
53+
mlir::Value plusConstant(mlir::Value v, mlir::Location loc, int addend);
5254
void transferToSCFForOp();
5355

5456
private:
5557
cir::ForOp forOp;
5658
cir::CmpOp cmpOp;
57-
mlir::Value IVAddr, lowerBound = nullptr, upperBound = nullptr;
59+
mlir::Value ivAddr, lowerBound = nullptr, upperBound = nullptr;
5860
mlir::ConversionPatternRewriter *rewriter;
5961
int64_t step = 0;
62+
bool canonical = true;
6063
};
6164

6265
class SCFWhileLoop {
@@ -86,47 +89,97 @@ class SCFDoLoop {
8689
};
8790

8891
static int64_t getConstant(cir::ConstantOp op) {
89-
auto attr = op->getAttrs().front().getValue();
90-
const auto IntAttr = mlir::dyn_cast<cir::IntAttr>(attr);
91-
return IntAttr.getValue().getSExtValue();
92+
auto attr = op.getValue();
93+
const auto intAttr = mlir::cast<cir::IntAttr>(attr);
94+
return intAttr.getValue().getSExtValue();
9295
}
9396

94-
int64_t SCFLoop::findStepAndIV(mlir::Value &addr) {
97+
mlir::LogicalResult SCFLoop::findStepAndIV() {
9598
auto *stepBlock =
9699
(forOp.maybeGetStep() ? &forOp.maybeGetStep()->front() : nullptr);
97100
assert(stepBlock && "Can not find step block");
98101

99-
int64_t step = 0;
100-
mlir::Value IV = nullptr;
101-
// Try to match "IV load addr; ++IV; store IV, addr" to find step.
102-
for (mlir::Operation &op : *stepBlock)
103-
if (auto loadOp = dyn_cast<cir::LoadOp>(op)) {
104-
addr = loadOp.getAddr();
105-
IV = loadOp.getResult();
106-
} else if (auto cop = dyn_cast<cir::ConstantOp>(op)) {
107-
if (step)
108-
llvm_unreachable(
109-
"Not support multiple constant in step calculation yet");
110-
step = getConstant(cop);
111-
} else if (auto bop = dyn_cast<cir::BinOp>(op)) {
112-
if (bop.getLhs() != IV)
113-
llvm_unreachable("Find BinOp not operate on IV");
114-
if (bop.getKind() != cir::BinOpKind::Add)
115-
llvm_unreachable(
116-
"Not support BinOp other than Add in step calculation yet");
117-
} else if (auto uop = dyn_cast<cir::UnaryOp>(op)) {
118-
if (uop.getInput() != IV)
119-
llvm_unreachable("Find UnaryOp not operate on IV");
120-
if (uop.getKind() == cir::UnaryOpKind::Inc)
121-
step = 1;
122-
else if (uop.getKind() == cir::UnaryOpKind::Dec)
123-
llvm_unreachable("Not support decrement step yet");
124-
} else if (auto storeOp = dyn_cast<cir::StoreOp>(op)) {
125-
assert(storeOp.getAddr() == addr && "Can't find IV when lowering ForOp");
126-
}
127-
assert(step && "Can't find step when lowering ForOp");
102+
// Try to match "iv = load addr; ++iv; store iv, addr; yield" to find step.
103+
// We should match the exact pattern, in case there's something unexpected:
104+
// we must rule out cases like `for (int i = 0; i < n; i++, printf("\n"))`.
105+
auto &oplist = stepBlock->getOperations();
106+
107+
auto iterator = oplist.begin();
108+
109+
// We might find constants at beginning. Skip them.
110+
// We could have hoisted them outside the for loop in previous passes, but
111+
// it hasn't been done yet.
112+
while (iterator != oplist.end() && isa<ConstantOp>(*iterator))
113+
++iterator;
114+
115+
if (iterator == oplist.end())
116+
return mlir::failure();
117+
118+
auto load = dyn_cast<LoadOp>(*iterator);
119+
if (!load)
120+
return mlir::failure();
121+
122+
// We assume this is the address of induction variable (IV). The operations
123+
// that come next will check if that's true.
124+
mlir::Value addr = load.getAddr();
125+
mlir::Value iv = load.getResult();
126+
127+
// Then we try to match either "++IV" or "IV += n". Same for reversed loops.
128+
if (++iterator == oplist.end())
129+
return mlir::failure();
130+
131+
mlir::Operation &arith = *iterator;
132+
133+
if (auto unary = dyn_cast<UnaryOp>(arith)) {
134+
// Not operating on induction variable. Fail.
135+
if (unary.getInput() != iv)
136+
return mlir::failure();
137+
138+
if (unary.getKind() == UnaryOpKind::Inc)
139+
step = 1;
140+
else if (unary.getKind() == UnaryOpKind::Dec)
141+
step = -1;
142+
else
143+
return mlir::failure();
144+
}
145+
146+
if (auto binary = dyn_cast<BinOp>(arith)) {
147+
if (binary.getLhs() != iv)
148+
return mlir::failure();
149+
150+
mlir::Value value = binary.getRhs();
151+
if (auto constValue = dyn_cast<ConstantOp>(value.getDefiningOp());
152+
isa<IntAttr>(constValue.getValue()))
153+
step = getConstant(constValue);
154+
155+
if (binary.getKind() == BinOpKind::Add)
156+
; // Nothing to do. Step has been calculated above.
157+
else if (binary.getKind() == BinOpKind::Sub)
158+
step = -step;
159+
else
160+
return mlir::failure();
161+
}
162+
163+
// Check whether we immediately store this value into the appropriate place.
164+
if (++iterator == oplist.end())
165+
return mlir::failure();
166+
167+
auto store = dyn_cast<StoreOp>(*iterator);
168+
if (!store || store.getAddr() != addr ||
169+
store.getValue() != arith.getResult(0))
170+
return mlir::failure();
171+
172+
if (++iterator == oplist.end())
173+
return mlir::failure();
128174

129-
return step;
175+
// Finally, this should precede a yield with nothing in between.
176+
bool success = isa<YieldOp>(*iterator);
177+
178+
// Remember to update analysis information.
179+
if (success)
180+
ivAddr = addr;
181+
182+
return success ? mlir::success() : mlir::failure();
130183
}
131184

132185
static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
@@ -143,7 +196,7 @@ static bool isIVLoad(mlir::Operation *op, mlir::Value IVAddr) {
143196

144197
cir::CmpOp SCFLoop::findCmpOp() {
145198
cmpOp = nullptr;
146-
for (auto *user : IVAddr.getUsers()) {
199+
for (auto *user : ivAddr.getUsers()) {
147200
if (user->getParentRegion() != &forOp.getCond())
148201
continue;
149202
if (auto loadOp = dyn_cast<cir::LoadOp>(*user)) {
@@ -162,10 +215,10 @@ cir::CmpOp SCFLoop::findCmpOp() {
162215
if (!mlir::isa<cir::IntType>(type))
163216
llvm_unreachable("Non-integer type IV is not supported");
164217

165-
auto lhsDefOp = cmpOp.getLhs().getDefiningOp();
218+
auto *lhsDefOp = cmpOp.getLhs().getDefiningOp();
166219
if (!lhsDefOp)
167220
llvm_unreachable("Can't find IV load");
168-
if (!isIVLoad(lhsDefOp, IVAddr))
221+
if (!isIVLoad(lhsDefOp, ivAddr))
169222
llvm_unreachable("cmpOp LHS is not IV");
170223

171224
if (cmpOp.getKind() != cir::CmpOpKind::le &&
@@ -187,7 +240,7 @@ mlir::Value SCFLoop::plusConstant(mlir::Value V, mlir::Location loc,
187240
// The operations before the loop have been transferred to MLIR.
188241
// So we need to go through getRemappedValue to find the value.
189242
mlir::Value SCFLoop::findIVInitValue() {
190-
auto remapAddr = rewriter->getRemappedValue(IVAddr);
243+
auto remapAddr = rewriter->getRemappedValue(ivAddr);
191244
if (!remapAddr)
192245
return nullptr;
193246
if (!remapAddr.hasOneUse())
@@ -199,7 +252,13 @@ mlir::Value SCFLoop::findIVInitValue() {
199252
}
200253

201254
void SCFLoop::analysis() {
202-
step = findStepAndIV(IVAddr);
255+
canonical = mlir::succeeded(findStepAndIV());
256+
if (!canonical) {
257+
mlir::emitError(forOp.getLoc(),
258+
"cannot handle non-constant step for induction variable");
259+
return;
260+
}
261+
203262
cmpOp = findCmpOp();
204263
auto IVInit = findIVInitValue();
205264
// The loop end value should be hoisted out of loop by -cir-mlir-scf-prepare.
@@ -237,7 +296,7 @@ void SCFLoop::transferToSCFForOp() {
237296
llvm_unreachable(
238297
"Not support lowering loop with break, continue or if yet");
239298
// Replace the IV usage to scf loop induction variable.
240-
if (isIVLoad(op, IVAddr)) {
299+
if (isIVLoad(op, ivAddr)) {
241300
// Replace CIR IV load with arith.addi scf.IV, 0.
242301
// The replacement lets the SCF IV be automatically propagated
243302
// by OpAdaptor for individual IV user lowering.
@@ -293,6 +352,10 @@ class CIRForOpLowering : public mlir::OpConversionPattern<cir::ForOp> {
293352
mlir::ConversionPatternRewriter &rewriter) const override {
294353
SCFLoop loop(op, &rewriter);
295354
loop.analysis();
355+
if (!loop.isCanonical())
356+
return mlir::emitError(op.getLoc(),
357+
"cannot handle non-canonicalized loop");
358+
296359
loop.transferToSCFForOp();
297360
rewriter.eraseOp(op);
298361
return mlir::success();
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// RUN: not %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o - 2>&1 | FileCheck %s
2+
3+
void f();
4+
5+
void reject() {
6+
for (int i = 0; i < 100; i++, f());
7+
// CHECK: failed to legalize operation 'cir.for'
8+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// RUN: not %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir=core %s -o - 2>&1 | FileCheck %s
2+
3+
void reject() {
4+
for (int i = 0; i < 100; i++, i++);
5+
// CHECK: failed to legalize operation 'cir.for'
6+
}

0 commit comments

Comments
 (0)