Skip to content

Commit 7d5d27e

Browse files
committed
[CIR] Upstream support for range-based for loops
This upstreams the code needed to handle CXXForRangeStmt.
1 parent 74f55c7 commit 7d5d27e

File tree

5 files changed

+234
-1
lines changed

5 files changed

+234
-1
lines changed

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -948,6 +948,41 @@ void CIRGenFunction::emitIgnoredExpr(const Expr *e) {
948948
emitLValue(e);
949949
}
950950

951+
/// Emit the address produced by an array-to-pointer decay of \p e, i.e. a
/// pointer to the first element of the array lvalue. For variable-length
/// arrays the lvalue address is already element-typed and is returned as-is.
Address CIRGenFunction::emitArrayToPointerDecay(const Expr *e) {
  assert(e->getType()->isArrayType() &&
         "Array to pointer decay must have array source type!");

  // Expressions of array type can't be bitfields or vector elements.
  LValue arrayLV = emitLValue(e);
  Address arrayAddr = arrayLV.getAddress();

  // If the array type was an incomplete type, we need to make sure
  // the decay ends up being the right type.
  auto addrPtrTy =
      mlir::cast<cir::PointerType>(arrayAddr.getPointer().getType());

  // VLA addresses already point at the element type; no decay op is needed.
  if (e->getType()->isVariableArrayType())
    return arrayAddr;

  auto pointeeArrayTy = mlir::cast<cir::ArrayType>(addrPtrTy.getPointee());

  // Sanity check: the lvalue's pointee must match the converted array type.
  mlir::Type convertedArrayTy = convertType(e->getType());
  assert(mlir::isa<cir::ArrayType>(convertedArrayTy) && "expected array");
  assert(pointeeArrayTy == convertedArrayTy);

  // The result of this decay conversion points to an array element within
  // the base lvalue. However, since TBAA currently does not support
  // representing accesses to elements of member arrays, we conservatively
  // represent accesses to the pointee object as if it had no base lvalue.
  // TODO: Support TBAA for member arrays.
  QualType elementType =
      e->getType()->castAsArrayTypeUnsafe()->getElementType();
  assert(!cir::MissingFeatures::opTBAA());

  // Build the cir.cast(array_to_ptrdecay) (if required) and rewrap the
  // result with the original alignment.
  mlir::Value decayedPtr = builder.maybeBuildArrayDecay(
      cgm.getLoc(e->getSourceRange()), arrayAddr.getPointer(),
      convertTypeForMem(elementType));
  return Address(decayedPtr, arrayAddr.getAlignment());
}
985+
951986
/// Emit an `if` on a boolean condition, filling `then` and `else` into
952987
/// appropriated regions.
953988
mlir::LogicalResult CIRGenFunction::emitIfOnBoolExpr(const Expr *cond,

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1567,6 +1567,9 @@ mlir::Value ScalarExprEmitter::VisitCastExpr(CastExpr *ce) {
15671567
return v;
15681568
}
15691569

1570+
case CK_ArrayToPointerDecay:
1571+
return cgf.emitArrayToPointerDecay(subExpr).getPointer();
1572+
15701573
case CK_NullToPointer: {
15711574
if (mustVisitNullValue(subExpr))
15721575
cgf.emitIgnoredExpr(subExpr);

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,8 @@ class CIRGenFunction : public CIRGenTypeCache {
449449

450450
LValue emitArraySubscriptExpr(const clang::ArraySubscriptExpr *e);
451451

452+
Address emitArrayToPointerDecay(const Expr *array);
453+
452454
AutoVarEmission emitAutoVarAlloca(const clang::VarDecl &d);
453455

454456
/// Emit code and set up symbol table for a variable declaration with auto,
@@ -485,6 +487,10 @@ class CIRGenFunction : public CIRGenTypeCache {
485487
LValue emitCompoundAssignmentLValue(const clang::CompoundAssignOperator *e);
486488

487489
mlir::LogicalResult emitContinueStmt(const clang::ContinueStmt &s);
490+
491+
mlir::LogicalResult emitCXXForRangeStmt(const CXXForRangeStmt &s,
492+
llvm::ArrayRef<const Attr *> attrs);
493+
488494
mlir::LogicalResult emitDoStmt(const clang::DoStmt &s);
489495

490496
/// Emit an expression as an initializer for an object (variable, field, etc.)

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
9797
return emitWhileStmt(cast<WhileStmt>(*s));
9898
case Stmt::DoStmtClass:
9999
return emitDoStmt(cast<DoStmt>(*s));
100+
case Stmt::CXXForRangeStmtClass:
101+
return emitCXXForRangeStmt(cast<CXXForRangeStmt>(*s), attr);
100102
case Stmt::OpenACCComputeConstructClass:
101103
return emitOpenACCComputeConstruct(cast<OpenACCComputeConstruct>(*s));
102104
case Stmt::OpenACCLoopConstructClass:
@@ -137,7 +139,6 @@ mlir::LogicalResult CIRGenFunction::emitStmt(const Stmt *s,
137139
case Stmt::CoroutineBodyStmtClass:
138140
case Stmt::CoreturnStmtClass:
139141
case Stmt::CXXTryStmtClass:
140-
case Stmt::CXXForRangeStmtClass:
141142
case Stmt::IndirectGotoStmtClass:
142143
case Stmt::GCCAsmStmtClass:
143144
case Stmt::MSAsmStmtClass:
@@ -547,6 +548,83 @@ mlir::LogicalResult CIRGenFunction::emitSwitchCase(const SwitchCase &s,
547548
llvm_unreachable("expect case or default stmt");
548549
}
549550

551+
/// Emit a C++ range-based for statement (CXXForRangeStmt) as a cir.for op
/// wrapped in a cir.scope. The range/begin/end variable declarations are
/// emitted before the loop; the loop variable is re-emitted in the body each
/// iteration. \p forAttrs is currently unused (see TODO below).
mlir::LogicalResult
CIRGenFunction::emitCXXForRangeStmt(const CXXForRangeStmt &s,
                                    ArrayRef<const Attr *> forAttrs) {
  // Captured by reference so the scope builder below can reach the created
  // op for terminateBody() after the scope has been emitted.
  cir::ForOp forOp;

  // TODO(cir): pass in array of attributes.
  auto forStmtBuilder = [&]() -> mlir::LogicalResult {
    mlir::LogicalResult loopRes = mlir::success();
    // Evaluate the first pieces before the loop: the optional C++20
    // init-statement, then the __range, __begin and __end declarations.
    if (s.getInit())
      if (emitStmt(s.getInit(), /*useCurrentScope=*/true).failed())
        return mlir::failure();
    if (emitStmt(s.getRangeStmt(), /*useCurrentScope=*/true).failed())
      return mlir::failure();
    if (emitStmt(s.getBeginStmt(), /*useCurrentScope=*/true).failed())
      return mlir::failure();
    if (emitStmt(s.getEndStmt(), /*useCurrentScope=*/true).failed())
      return mlir::failure();

    assert(!cir::MissingFeatures::loopInfoStack());
    // From LLVM: if there are any cleanups between here and the loop-exit
    // scope, create a block to stage a loop exit along.
    // We probably already do the right thing because of ScopeOp, but make
    // sure we handle all cases.
    assert(!cir::MissingFeatures::requiresCleanups());

    forOp = builder.createFor(
        getLoc(s.getSourceRange()),
        /*condBuilder=*/
        [&](mlir::OpBuilder &b, mlir::Location loc) {
          assert(!cir::MissingFeatures::createProfileWeightsForLoop());
          assert(!cir::MissingFeatures::emitCondLikelihoodViaExpectIntrinsic());
          // __begin != __end, already synthesized by Sema as s.getCond().
          mlir::Value condVal = evaluateExprAsBool(s.getCond());
          builder.createCondition(condVal);
        },
        /*bodyBuilder=*/
        [&](mlir::OpBuilder &b, mlir::Location loc) {
          // https://en.cppreference.com/w/cpp/language/for
          // In C++ the scope of the init-statement and the scope of
          // statement are one and the same.
          bool useCurrentScope = true;
          // Declare the loop variable (e.g. `int n = *__begin`) before the
          // user-written body.
          if (emitStmt(s.getLoopVarStmt(), useCurrentScope).failed())
            loopRes = mlir::failure();
          if (emitStmt(s.getBody(), useCurrentScope).failed())
            loopRes = mlir::failure();
          emitStopPoint(&s);
        },
        /*stepBuilder=*/
        [&](mlir::OpBuilder &b, mlir::Location loc) {
          // ++__begin, synthesized by Sema as s.getInc().
          if (s.getInc())
            if (emitStmt(s.getInc(), /*useCurrentScope=*/true).failed())
              loopRes = mlir::failure();
          builder.createYield(loc);
        });
    return loopRes;
  };

  mlir::LogicalResult res = mlir::success();
  mlir::Location scopeLoc = getLoc(s.getSourceRange());
  builder.create<cir::ScopeOp>(scopeLoc, /*scopeBuilder=*/
                               [&](mlir::OpBuilder &b, mlir::Location loc) {
                                 // Create a cleanup scope for the condition
                                 // variable cleanups. Logical equivalent from
                                 // LLVM codegen for LexicalScope
                                 // ConditionScope(*this, S.getSourceRange())...
                                 LexicalScope lexScope{
                                     *this, loc, builder.getInsertionBlock()};
                                 res = forStmtBuilder();
                               });

  if (res.failed())
    return res;

  // Ensure the loop body region ends with a terminator (e.g. cir.yield).
  terminateBody(builder, forOp.getBody(), getLoc(s.getEndLoc()));
  return mlir::success();
}
627+
550628
mlir::LogicalResult CIRGenFunction::emitForStmt(const ForStmt &s) {
551629
cir::ForOp forOp;
552630

clang/test/CIR/CodeGen/loop.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,117 @@ void l3() {
190190
// OGCG: store i32 0, ptr %[[I]], align 4
191191
// OGCG: br label %[[FOR_COND]]
192192

193+
// Range-based for over a fixed-size array with an empty body; exercises the
// array-to-pointer decay and the synthesized __range1/__begin1/__end1
// variables checked by the CIR/LLVM/OGCG lines below.
void l4() {
  int a[10];
  for (int n : a)
    ;
}
198+
199+
// CIR: cir.func @_Z2l4v
200+
// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.array<!s32i x 10>, !cir.ptr<!cir.array<!s32i x 10>>, ["a"] {alignment = 16 : i64}
201+
// CIR: cir.scope {
202+
// CIR: %[[RANGE_ADDR:.*]] = cir.alloca !cir.ptr<!cir.array<!s32i x 10>>, !cir.ptr<!cir.ptr<!cir.array<!s32i x 10>>>, ["__range1", init, const] {alignment = 8 : i64}
203+
// CIR: %[[BEGIN_ADDR:.*]] = cir.alloca !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>, ["__begin1", init] {alignment = 8 : i64}
204+
// CIR: %[[END_ADDR:.*]] = cir.alloca !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>, ["__end1", init] {alignment = 8 : i64}
205+
// CIR: %[[N_ADDR:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["n", init] {alignment = 4 : i64}
206+
// CIR: cir.store %[[A_ADDR]], %[[RANGE_ADDR]] : !cir.ptr<!cir.array<!s32i x 10>>, !cir.ptr<!cir.ptr<!cir.array<!s32i x 10>>>
207+
// CIR: %[[RANGE_LOAD:.*]] = cir.load %[[RANGE_ADDR]] : !cir.ptr<!cir.ptr<!cir.array<!s32i x 10>>>, !cir.ptr<!cir.array<!s32i x 10>>
208+
// CIR: %[[RANGE_CAST:.*]] = cir.cast(array_to_ptrdecay, %[[RANGE_LOAD]] : !cir.ptr<!cir.array<!s32i x 10>>), !cir.ptr<!s32i>
209+
// CIR: cir.store %[[RANGE_CAST]], %[[BEGIN_ADDR]] : !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>
210+
// CIR: %[[BEGIN:.*]] = cir.load %[[RANGE_ADDR]] : !cir.ptr<!cir.ptr<!cir.array<!s32i x 10>>>, !cir.ptr<!cir.array<!s32i x 10>>
211+
// CIR: %[[BEGIN_CAST:.*]] = cir.cast(array_to_ptrdecay, %[[BEGIN]] : !cir.ptr<!cir.array<!s32i x 10>>), !cir.ptr<!s32i>
212+
// CIR: %[[TEN:.*]] = cir.const #cir.int<10> : !s64i
213+
// CIR: %[[END_PTR:.*]] = cir.ptr_stride(%[[BEGIN_CAST]] : !cir.ptr<!s32i>, %[[TEN]] : !s64i), !cir.ptr<!s32i>
214+
// CIR: cir.store %[[END_PTR]], %[[END_ADDR]] : !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>
215+
// CIR: cir.for : cond {
216+
// CIR: %[[CUR:.*]] = cir.load %[[BEGIN_ADDR]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
217+
// CIR: %[[END:.*]] = cir.load %[[END_ADDR]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
218+
// CIR: %[[CMP:.*]] = cir.cmp(ne, %[[CUR]], %[[END]]) : !cir.ptr<!s32i>, !cir.bool
219+
// CIR: cir.condition(%[[CMP]])
220+
// CIR: } body {
221+
// CIR: %[[CUR:.*]] = cir.load deref %[[BEGIN_ADDR]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
222+
// CIR: %[[N:.*]] = cir.load %[[CUR]] : !cir.ptr<!s32i>, !s32i
223+
// CIR: cir.store %[[N]], %[[N_ADDR]] : !s32i, !cir.ptr<!s32i>
224+
// CIR: cir.yield
225+
// CIR: } step {
226+
// CIR: %[[CUR:.*]] = cir.load %[[BEGIN_ADDR]] : !cir.ptr<!cir.ptr<!s32i>>, !cir.ptr<!s32i>
227+
// CIR: %[[ONE:.*]] = cir.const #cir.int<1> : !s32i
228+
// CIR: %[[NEXT:.*]] = cir.ptr_stride(%[[CUR]] : !cir.ptr<!s32i>, %[[ONE]] : !s32i), !cir.ptr<!s32i>
229+
// CIR: cir.store %[[NEXT]], %[[BEGIN_ADDR]] : !cir.ptr<!s32i>, !cir.ptr<!cir.ptr<!s32i>>
230+
// CIR: cir.yield
231+
// CIR: }
232+
// CIR: }
233+
234+
// LLVM: define void @_Z2l4v() {
235+
// LLVM: %[[RANGE_ADDR:.*]] = alloca ptr, i64 1, align 8
236+
// LLVM: %[[BEGIN_ADDR:.*]] = alloca ptr, i64 1, align 8
237+
// LLVM: %[[END_ADDR:.*]] = alloca ptr, i64 1, align 8
238+
// LLVM: %[[N_ADDR:.*]] = alloca i32, i64 1, align 4
239+
// LLVM: %[[A_ADDR:.*]] = alloca [10 x i32], i64 1, align 16
240+
// LLVM: br label %[[SETUP:.*]]
241+
// LLVM: [[SETUP]]:
242+
// LLVM: store ptr %[[A_ADDR]], ptr %[[RANGE_ADDR]], align 8
243+
// LLVM: %[[BEGIN:.*]] = load ptr, ptr %[[RANGE_ADDR]], align 8
244+
// LLVM: %[[BEGIN_CAST:.*]] = getelementptr i32, ptr %[[BEGIN]], i32 0
245+
// LLVM: store ptr %[[BEGIN_CAST]], ptr %[[BEGIN_ADDR]], align 8
246+
// LLVM: %[[RANGE:.*]] = load ptr, ptr %[[RANGE_ADDR]], align 8
247+
// LLVM: %[[RANGE_CAST:.*]] = getelementptr i32, ptr %[[RANGE]], i32 0
248+
// LLVM: %[[END_PTR:.*]] = getelementptr i32, ptr %[[RANGE_CAST]], i64 10
249+
// LLVM: store ptr %[[END_PTR]], ptr %[[END_ADDR]], align 8
250+
// LLVM: br label %[[COND:.*]]
251+
// LLVM: [[COND]]:
252+
// LLVM: %[[BEGIN:.*]] = load ptr, ptr %[[BEGIN_ADDR]], align 8
253+
// LLVM: %[[END:.*]] = load ptr, ptr %[[END_ADDR]], align 8
254+
// LLVM: %[[CMP:.*]] = icmp ne ptr %[[BEGIN]], %[[END]]
255+
// LLVM: br i1 %[[CMP]], label %[[BODY:.*]], label %[[END:.*]]
256+
// LLVM: [[BODY]]:
257+
// LLVM: %[[CUR:.*]] = load ptr, ptr %[[BEGIN_ADDR]], align 8
258+
// LLVM: %[[A_CUR:.*]] = load i32, ptr %[[CUR]], align 4
259+
// LLVM: store i32 %[[A_CUR]], ptr %[[N_ADDR]], align 4
260+
// LLVM: br label %[[STEP:.*]]
261+
// LLVM: [[STEP]]:
262+
// LLVM: %[[BEGIN:.*]] = load ptr, ptr %[[BEGIN_ADDR]], align 8
263+
// LLVM: %[[NEXT:.*]] = getelementptr i32, ptr %[[BEGIN]], i64 1
264+
// LLVM: store ptr %[[NEXT]], ptr %[[BEGIN_ADDR]], align 8
265+
// LLVM: br label %[[COND]]
266+
// LLVM: [[END]]:
267+
// LLVM: br label %[[EXIT:.*]]
268+
// LLVM: [[EXIT]]:
269+
// LLVM: ret void
270+
271+
// OGCG: define{{.*}} void @_Z2l4v()
272+
// OGCG: %[[A_ADDR:.*]] = alloca [10 x i32], align 16
273+
// OGCG: %[[RANGE_ADDR:.*]] = alloca ptr, align 8
274+
// OGCG: %[[BEGIN_ADDR:.*]] = alloca ptr, align 8
275+
// OGCG: %[[END_ADDR:.*]] = alloca ptr, align 8
276+
// OGCG: %[[N_ADDR:.*]] = alloca i32, align 4
277+
// OGCG: store ptr %[[A_ADDR]], ptr %[[RANGE_ADDR]], align 8
278+
// OGCG: %[[BEGIN:.*]] = load ptr, ptr %[[RANGE_ADDR]], align 8
279+
// OGCG: %[[BEGIN_CAST:.*]] = getelementptr inbounds [10 x i32], ptr %[[BEGIN]], i64 0, i64 0
280+
// OGCG: store ptr %[[BEGIN_CAST]], ptr %[[BEGIN_ADDR]], align 8
281+
// OGCG: %[[RANGE:.*]] = load ptr, ptr %[[RANGE_ADDR]], align 8
282+
// OGCG: %[[RANGE_CAST:.*]] = getelementptr inbounds [10 x i32], ptr %[[RANGE]], i64 0, i64 0
283+
// OGCG: %[[END_PTR:.*]] = getelementptr inbounds i32, ptr %[[RANGE_CAST]], i64 10
284+
// OGCG: store ptr %[[END_PTR]], ptr %[[END_ADDR]], align 8
285+
// OGCG: br label %[[COND:.*]]
286+
// OGCG: [[COND]]:
287+
// OGCG: %[[BEGIN:.*]] = load ptr, ptr %[[BEGIN_ADDR]], align 8
288+
// OGCG: %[[END:.*]] = load ptr, ptr %[[END_ADDR]], align 8
289+
// OGCG: %[[CMP:.*]] = icmp ne ptr %[[BEGIN]], %[[END]]
290+
// OGCG: br i1 %[[CMP]], label %[[BODY:.*]], label %[[END:.*]]
291+
// OGCG: [[BODY]]:
292+
// OGCG: %[[CUR:.*]] = load ptr, ptr %[[BEGIN_ADDR]], align 8
293+
// OGCG: %[[A_CUR:.*]] = load i32, ptr %[[CUR]], align 4
294+
// OGCG: store i32 %[[A_CUR]], ptr %[[N_ADDR]], align 4
295+
// OGCG: br label %[[STEP:.*]]
296+
// OGCG: [[STEP]]:
297+
// OGCG: %[[BEGIN:.*]] = load ptr, ptr %[[BEGIN_ADDR]], align 8
298+
// OGCG: %[[NEXT:.*]] = getelementptr inbounds nuw i32, ptr %[[BEGIN]], i32 1
299+
// OGCG: store ptr %[[NEXT]], ptr %[[BEGIN_ADDR]], align 8
300+
// OGCG: br label %[[COND]]
301+
// OGCG: [[END]]:
302+
// OGCG: ret void
303+
193304
void test_do_while_false() {
194305
do {
195306
} while (0);

0 commit comments

Comments
 (0)