Skip to content

Commit b5e5794

Browse files
authored
[CIR] Implement Statement Expressions (#153677)
Depends on #153625 This patch adds support for statement expressions. It also changes emitCompoundStmt and emitCompoundStmtWithoutScope to accept an Address that the optional result is written to. This allows the creation of the alloca ahead of the creation of the scope which saves us from hoisting the alloca to its parent scope.
1 parent 9adc4f9 commit b5e5794

File tree

9 files changed

+506
-45
lines changed

9 files changed

+506
-45
lines changed

clang/lib/CIR/CodeGen/CIRGenExpr.cpp

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,28 +1170,33 @@ static void pushTemporaryCleanup(CIRGenFunction &cgf,
11701170
return;
11711171
}
11721172

1173-
CXXDestructorDecl *referenceTemporaryDtor = nullptr;
1174-
if (const clang::RecordType *rt = e->getType()
1175-
->getBaseElementTypeUnsafe()
1176-
->getAs<clang::RecordType>()) {
1177-
// Get the destructor for the reference temporary.
1178-
auto *classDecl =
1179-
cast<CXXRecordDecl>(rt->getOriginalDecl()->getDefinitionOrSelf());
1180-
if (!classDecl->hasTrivialDestructor())
1181-
referenceTemporaryDtor =
1182-
classDecl->getDefinitionOrSelf()->getDestructor();
1183-
}
1184-
1185-
if (!referenceTemporaryDtor)
1173+
const QualType::DestructionKind dk = e->getType().isDestructedType();
1174+
if (dk == QualType::DK_none)
11861175
return;
11871176

1188-
// Call the destructor for the temporary.
11891177
switch (m->getStorageDuration()) {
11901178
case SD_Static:
1191-
case SD_Thread:
1192-
cgf.cgm.errorNYI(e->getSourceRange(),
1193-
"pushTemporaryCleanup: static/thread storage duration");
1194-
return;
1179+
case SD_Thread: {
1180+
CXXDestructorDecl *referenceTemporaryDtor = nullptr;
1181+
if (const clang::RecordType *rt = e->getType()
1182+
->getBaseElementTypeUnsafe()
1183+
->getAs<clang::RecordType>()) {
1184+
// Get the destructor for the reference temporary.
1185+
if (const auto *classDecl = dyn_cast<CXXRecordDecl>(
1186+
rt->getOriginalDecl()->getDefinitionOrSelf())) {
1187+
if (!classDecl->hasTrivialDestructor())
1188+
referenceTemporaryDtor =
1189+
classDecl->getDefinitionOrSelf()->getDestructor();
1190+
}
1191+
}
1192+
1193+
if (!referenceTemporaryDtor)
1194+
return;
1195+
1196+
cgf.cgm.errorNYI(e->getSourceRange(), "pushTemporaryCleanup: static/thread "
1197+
"storage duration with destructors");
1198+
break;
1199+
}
11951200

11961201
case SD_FullExpression:
11971202
cgf.pushDestroy(NormalAndEHCleanup, referenceTemporary, e->getType(),

clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
6969
void Visit(Expr *e) { StmtVisitor<AggExprEmitter>::Visit(e); }
7070

7171
void VisitCallExpr(const CallExpr *e);
72+
void VisitStmtExpr(const StmtExpr *e) {
73+
CIRGenFunction::StmtExprEvaluation eval(cgf);
74+
Address retAlloca =
75+
cgf.createMemTemp(e->getType(), cgf.getLoc(e->getSourceRange()));
76+
(void)cgf.emitCompoundStmt(*e->getSubStmt(), &retAlloca, dest);
77+
}
7278

7379
void VisitDeclRefExpr(DeclRefExpr *e) { emitAggLoadOfLValue(e); }
7480

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,21 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
185185
mlir::Value VisitCastExpr(CastExpr *e);
186186
mlir::Value VisitCallExpr(const CallExpr *e);
187187

188+
mlir::Value VisitStmtExpr(StmtExpr *e) {
189+
CIRGenFunction::StmtExprEvaluation eval(cgf);
190+
if (e->getType()->isVoidType()) {
191+
(void)cgf.emitCompoundStmt(*e->getSubStmt());
192+
return {};
193+
}
194+
195+
Address retAlloca =
196+
cgf.createMemTemp(e->getType(), cgf.getLoc(e->getSourceRange()));
197+
(void)cgf.emitCompoundStmt(*e->getSubStmt(), &retAlloca);
198+
199+
return cgf.emitLoadOfScalar(cgf.makeAddrLValue(retAlloca, e->getType()),
200+
e->getExprLoc());
201+
}
202+
188203
mlir::Value VisitArraySubscriptExpr(ArraySubscriptExpr *e) {
189204
if (e->getBase()->getType()->isVectorType()) {
190205
assert(!cir::MissingFeatures::scalableVectors());

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -490,13 +490,10 @@ mlir::LogicalResult CIRGenFunction::emitFunctionBody(const clang::Stmt *body) {
490490
// We start with function level scope for variables.
491491
SymTableScopeTy varScope(symbolTable);
492492

493-
auto result = mlir::LogicalResult::success();
494493
if (const CompoundStmt *block = dyn_cast<CompoundStmt>(body))
495-
emitCompoundStmtWithoutScope(*block);
496-
else
497-
result = emitStmt(body, /*useCurrentScope=*/true);
494+
return emitCompoundStmtWithoutScope(*block);
498495

499-
return result;
496+
return emitStmt(body, /*useCurrentScope=*/true);
500497
}
501498

502499
static void eraseEmptyAndUnusedBlocks(cir::FuncOp func) {
@@ -561,7 +558,6 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl gd, cir::FuncOp fn,
561558
emitImplicitAssignmentOperatorBody(args);
562559
} else if (body) {
563560
if (mlir::failed(emitFunctionBody(body))) {
564-
fn.erase();
565561
return nullptr;
566562
}
567563
} else {

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,9 +1174,14 @@ class CIRGenFunction : public CIRGenTypeCache {
11741174
LValue emitScalarCompoundAssignWithComplex(const CompoundAssignOperator *e,
11751175
mlir::Value &result);
11761176

1177-
void emitCompoundStmt(const clang::CompoundStmt &s);
1177+
mlir::LogicalResult
1178+
emitCompoundStmt(const clang::CompoundStmt &s, Address *lastValue = nullptr,
1179+
AggValueSlot slot = AggValueSlot::ignored());
11781180

1179-
void emitCompoundStmtWithoutScope(const clang::CompoundStmt &s);
1181+
mlir::LogicalResult
1182+
emitCompoundStmtWithoutScope(const clang::CompoundStmt &s,
1183+
Address *lastValue = nullptr,
1184+
AggValueSlot slot = AggValueSlot::ignored());
11801185

11811186
void emitDecl(const clang::Decl &d, bool evaluateConditionDecl = false);
11821187
mlir::LogicalResult emitDeclStmt(const clang::DeclStmt &s);
@@ -1413,6 +1418,27 @@ class CIRGenFunction : public CIRGenTypeCache {
14131418
// we know if a temporary should be destroyed conditionally.
14141419
ConditionalEvaluation *outermostConditional = nullptr;
14151420

1421+
/// An RAII object to record that we're evaluating a statement
1422+
/// expression.
1423+
class StmtExprEvaluation {
1424+
CIRGenFunction &cgf;
1425+
1426+
/// We have to save the outermost conditional: cleanups in a
1427+
/// statement expression aren't conditional just because the
1428+
/// StmtExpr is.
1429+
ConditionalEvaluation *savedOutermostConditional;
1430+
1431+
public:
1432+
StmtExprEvaluation(CIRGenFunction &cgf)
1433+
: cgf(cgf), savedOutermostConditional(cgf.outermostConditional) {
1434+
cgf.outermostConditional = nullptr;
1435+
}
1436+
1437+
~StmtExprEvaluation() {
1438+
cgf.outermostConditional = savedOutermostConditional;
1439+
}
1440+
};
1441+
14161442
template <typename FuncTy>
14171443
ConditionalInfo emitConditionalBlocks(const AbstractConditionalOperator *e,
14181444
const FuncTy &branchGenFunc);

clang/lib/CIR/CodeGen/CIRGenStmt.cpp

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
#include "CIRGenFunction.h"
1515

1616
#include "mlir/IR/Builders.h"
17+
#include "mlir/IR/Location.h"
18+
#include "mlir/Support/LLVM.h"
1719
#include "clang/AST/ExprCXX.h"
1820
#include "clang/AST/Stmt.h"
1921
#include "clang/AST/StmtOpenACC.h"
@@ -23,16 +25,68 @@ using namespace clang;
2325
using namespace clang::CIRGen;
2426
using namespace cir;
2527

26-
void CIRGenFunction::emitCompoundStmtWithoutScope(const CompoundStmt &s) {
27-
for (auto *curStmt : s.body()) {
28-
if (emitStmt(curStmt, /*useCurrentScope=*/false).failed())
29-
getCIRGenModule().errorNYI(curStmt->getSourceRange(),
30-
std::string("emitCompoundStmtWithoutScope: ") +
31-
curStmt->getStmtClassName());
28+
static mlir::LogicalResult emitStmtWithResult(CIRGenFunction &cgf,
29+
const Stmt *exprResult,
30+
AggValueSlot slot,
31+
Address *lastValue) {
32+
// We have to special case labels here. They are statements, but when put
33+
// at the end of a statement expression, they yield the value of their
34+
// subexpression. Handle this by walking through all labels we encounter,
35+
// emitting them before we evaluate the subexpr.
36+
// Similar issues arise for attributed statements.
37+
while (!isa<Expr>(exprResult)) {
38+
if (const auto *ls = dyn_cast<LabelStmt>(exprResult)) {
39+
if (cgf.emitLabel(*ls->getDecl()).failed())
40+
return mlir::failure();
41+
exprResult = ls->getSubStmt();
42+
} else if (const auto *as = dyn_cast<AttributedStmt>(exprResult)) {
43+
// FIXME: Update this if we ever have attributes that affect the
44+
// semantics of an expression.
45+
exprResult = as->getSubStmt();
46+
} else {
47+
llvm_unreachable("Unknown value statement");
48+
}
49+
}
50+
51+
const Expr *e = cast<Expr>(exprResult);
52+
QualType exprTy = e->getType();
53+
if (cgf.hasAggregateEvaluationKind(exprTy)) {
54+
cgf.emitAggExpr(e, slot);
55+
} else {
56+
// We can't return an RValue here because there might be cleanups at
57+
// the end of the StmtExpr. Because of that, we have to emit the result
58+
// here into a temporary alloca.
59+
cgf.emitAnyExprToMem(e, *lastValue, Qualifiers(),
60+
/*IsInit*/ false);
61+
}
62+
63+
return mlir::success();
64+
}
65+
66+
mlir::LogicalResult CIRGenFunction::emitCompoundStmtWithoutScope(
67+
const CompoundStmt &s, Address *lastValue, AggValueSlot slot) {
68+
mlir::LogicalResult result = mlir::success();
69+
const Stmt *exprResult = s.getStmtExprResult();
70+
assert((!lastValue || (lastValue && exprResult)) &&
71+
"If lastValue is not null then the CompoundStmt must have a "
72+
"StmtExprResult");
73+
74+
for (const Stmt *curStmt : s.body()) {
75+
const bool saveResult = lastValue && exprResult == curStmt;
76+
if (saveResult) {
77+
if (emitStmtWithResult(*this, exprResult, slot, lastValue).failed())
78+
result = mlir::failure();
79+
} else {
80+
if (emitStmt(curStmt, /*useCurrentScope=*/false).failed())
81+
result = mlir::failure();
82+
}
3283
}
84+
return result;
3385
}
3486

35-
void CIRGenFunction::emitCompoundStmt(const CompoundStmt &s) {
87+
mlir::LogicalResult CIRGenFunction::emitCompoundStmt(const CompoundStmt &s,
88+
Address *lastValue,
89+
AggValueSlot slot) {
3690
// Add local scope to track new declared variables.
3791
SymTableScopeTy varScope(symbolTable);
3892
mlir::Location scopeLoc = getLoc(s.getSourceRange());
@@ -41,12 +95,10 @@ void CIRGenFunction::emitCompoundStmt(const CompoundStmt &s) {
4195
scopeLoc, [&](mlir::OpBuilder &b, mlir::Type &type, mlir::Location loc) {
4296
scopeInsPt = b.saveInsertionPoint();
4397
});
44-
{
45-
mlir::OpBuilder::InsertionGuard guard(builder);
46-
builder.restoreInsertionPoint(scopeInsPt);
47-
LexicalScope lexScope(*this, scopeLoc, builder.getInsertionBlock());
48-
emitCompoundStmtWithoutScope(s);
49-
}
98+
mlir::OpBuilder::InsertionGuard guard(builder);
99+
builder.restoreInsertionPoint(scopeInsPt);
100+
LexicalScope lexScope(*this, scopeLoc, builder.getInsertionBlock());
101+
return emitCompoundStmtWithoutScope(s, lastValue, slot);
50102
}
51103

52104
void CIRGenFunction::emitStopPoint(const Stmt *s) {
@@ -249,10 +301,8 @@ mlir::LogicalResult CIRGenFunction::emitSimpleStmt(const Stmt *s,
249301
return emitDeclStmt(cast<DeclStmt>(*s));
250302
case Stmt::CompoundStmtClass:
251303
if (useCurrentScope)
252-
emitCompoundStmtWithoutScope(cast<CompoundStmt>(*s));
253-
else
254-
emitCompoundStmt(cast<CompoundStmt>(*s));
255-
break;
304+
return emitCompoundStmtWithoutScope(cast<CompoundStmt>(*s));
305+
return emitCompoundStmt(cast<CompoundStmt>(*s));
256306
case Stmt::GotoStmtClass:
257307
return emitGotoStmt(cast<GotoStmt>(*s));
258308
case Stmt::ContinueStmtClass:

0 commit comments

Comments
 (0)