Skip to content

Commit 7fbcef9

Browse files
authored
Add support for if and while statements that declare a condition variable. (#271)
1 parent 8a21abf commit 7fbcef9

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

tools/cgeist/Lib/CGStmt.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -758,10 +758,10 @@ ValueCategory MLIRScanner::VisitDoStmt(clang::DoStmt *fors) {
758758
return nullptr;
759759
}
760760

761-
ValueCategory MLIRScanner::VisitWhileStmt(clang::WhileStmt *fors) {
761+
ValueCategory MLIRScanner::VisitWhileStmt(clang::WhileStmt *stmt) {
762762
IfScope scope(*this);
763763

764-
auto loc = getMLIRLocation(fors->getLParenLoc());
764+
auto loc = getMLIRLocation(stmt->getLParenLoc());
765765

766766
auto i1Ty = builder.getIntegerType(1);
767767
auto type = mlir::MemRefType::get({}, i1Ty, {}, 0);
@@ -782,7 +782,9 @@ ValueCategory MLIRScanner::VisitWhileStmt(clang::WhileStmt *fors) {
782782

783783
builder.setInsertionPointToStart(&condB);
784784

785-
if (auto *s = fors->getCond()) {
785+
if (auto declStmt = stmt->getConditionVariableDeclStmt())
786+
Visit(declStmt);
787+
if (auto *s = stmt->getCond()) {
786788
auto condRes = Visit(s);
787789
auto cond = condRes.getValue(loc, builder);
788790
if (auto LT = cond.getType().dyn_cast<mlir::LLVM::LLVMPointerType>()) {
@@ -809,7 +811,7 @@ ValueCategory MLIRScanner::VisitWhileStmt(clang::WhileStmt *fors) {
809811
std::vector<mlir::Value>()),
810812
loops.back().keepRunning, std::vector<mlir::Value>());
811813

812-
Visit(fors->getBody());
814+
Visit(stmt->getBody());
813815
loops.pop_back();
814816

815817
builder.create<mlir::cf::BranchOp>(loc, &condB);
@@ -822,6 +824,8 @@ ValueCategory MLIRScanner::VisitWhileStmt(clang::WhileStmt *fors) {
822824
ValueCategory MLIRScanner::VisitIfStmt(clang::IfStmt *stmt) {
823825
IfScope scope(*this);
824826
auto loc = getMLIRLocation(stmt->getIfLoc());
827+
if (auto declStmt = stmt->getConditionVariableDeclStmt())
828+
Visit(declStmt);
825829
auto cond = Visit(stmt->getCond()).getValue(loc, builder);
826830
assert(cond != nullptr && "must be a non-null");
827831

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: cgeist %s --function=* -S | FileCheck %s
2+
3+
struct A {
4+
int value;
5+
6+
int* getPointer() {
7+
if (int* tmp = &this->value) {
8+
return tmp;
9+
}
10+
return nullptr;
11+
}
12+
};
13+
14+
int main() {
15+
return *A().getPointer();
16+
}
17+
18+
// CHECK: func.func @_ZN1A10getPointerEv(
19+
// CHECK: "polygeist.memref2pointer"
20+
// CHECK: llvm.mlir.null
21+
// CHECK: llvm.icmp "ne"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: cgeist %s --function=* -S | FileCheck %s
2+
3+
struct A {
4+
int value;
5+
6+
int getValue() {
7+
while (int tmp = this->value) {
8+
tmp++;
9+
}
10+
return value;
11+
}
12+
};
13+
14+
int main() {
15+
return A().getValue();
16+
}
17+
18+
// CHECK: func.func @_ZN1A8getValueEv(
19+
// CHECK: scf.while
20+
// CHECK: arith.cmpi ne
21+
// CHECK: scf.condition

0 commit comments

Comments
 (0)