Skip to content

Commit cf9cb54

Browse files
[CIR] Emit promise declaration in coroutine (#166683)
This PR adds support for emitting the promise declaration in coroutines and obtaining the `get_return_object()`.
1 parent 260df80 commit cf9cb54

File tree

5 files changed

+145
-3
lines changed

5 files changed

+145
-3
lines changed

clang/include/clang/CIR/MissingFeatures.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ struct MissingFeatures {
153153
static bool coroEndBuiltinCall() { return false; }
154154
static bool coroutineFrame() { return false; }
155155
static bool emitBodyAndFallthrough() { return false; }
156+
static bool coroOutsideFrameMD() { return false; }
156157

157158
// Various handling of deferred processing in CIRGenModule.
158159
static bool cgmRelease() { return false; }

clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "CIRGenFunction.h"
1414
#include "mlir/Support/LLVM.h"
1515
#include "clang/AST/StmtCXX.h"
16+
#include "clang/AST/StmtVisitor.h"
1617
#include "clang/Basic/TargetInfo.h"
1718
#include "clang/CIR/Dialect/IR/CIRTypes.h"
1819
#include "clang/CIR/MissingFeatures.h"
@@ -33,6 +34,65 @@ struct clang::CIRGen::CGCoroData {
3334
CIRGenFunction::CGCoroInfo::CGCoroInfo() {}
3435
CIRGenFunction::CGCoroInfo::~CGCoroInfo() {}
3536

37+
namespace {
38+
// FIXME: both GetParamRef and ParamReferenceReplacerRAII are good template
39+
// candidates to be shared among LLVM / CIR codegen.
40+
41+
// Hunts for the parameter reference in the parameter copy/move declaration.
42+
struct GetParamRef : public StmtVisitor<GetParamRef> {
43+
public:
44+
DeclRefExpr *expr = nullptr;
45+
GetParamRef() {}
46+
void VisitDeclRefExpr(DeclRefExpr *e) {
47+
assert(expr == nullptr && "multilple declref in param move");
48+
expr = e;
49+
}
50+
void VisitStmt(Stmt *s) {
51+
for (Stmt *c : s->children()) {
52+
if (c)
53+
Visit(c);
54+
}
55+
}
56+
};
57+
58+
// This class replaces references to parameters to their copies by changing
59+
// the addresses in CGF.LocalDeclMap and restoring back the original values in
60+
// its destructor.
61+
struct ParamReferenceReplacerRAII {
62+
CIRGenFunction::DeclMapTy savedLocals;
63+
CIRGenFunction::DeclMapTy &localDeclMap;
64+
65+
ParamReferenceReplacerRAII(CIRGenFunction::DeclMapTy &localDeclMap)
66+
: localDeclMap(localDeclMap) {}
67+
68+
void addCopy(const DeclStmt *pm) {
69+
// Figure out what param it refers to.
70+
71+
assert(pm->isSingleDecl());
72+
const VarDecl *vd = static_cast<const VarDecl *>(pm->getSingleDecl());
73+
const Expr *initExpr = vd->getInit();
74+
GetParamRef visitor;
75+
visitor.Visit(const_cast<Expr *>(initExpr));
76+
assert(visitor.expr);
77+
DeclRefExpr *dreOrig = visitor.expr;
78+
auto *pd = dreOrig->getDecl();
79+
80+
auto it = localDeclMap.find(pd);
81+
assert(it != localDeclMap.end() && "parameter is not found");
82+
savedLocals.insert({pd, it->second});
83+
84+
auto copyIt = localDeclMap.find(vd);
85+
assert(copyIt != localDeclMap.end() && "parameter copy is not found");
86+
it->second = copyIt->getSecond();
87+
}
88+
89+
~ParamReferenceReplacerRAII() {
90+
for (auto &&savedLocal : savedLocals) {
91+
localDeclMap.insert({savedLocal.first, savedLocal.second});
92+
}
93+
}
94+
};
95+
} // namespace
3696
static void createCoroData(CIRGenFunction &cgf,
3797
CIRGenFunction::CGCoroInfo &curCoro,
3898
cir::CallOp coroId) {
@@ -149,7 +209,47 @@ CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) {
149209
if (s.getReturnStmtOnAllocFailure())
150210
cgm.errorNYI("handle coroutine return alloc failure");
151211

152-
assert(!cir::MissingFeatures::generateDebugInfo());
153-
assert(!cir::MissingFeatures::emitBodyAndFallthrough());
212+
{
213+
assert(!cir::MissingFeatures::generateDebugInfo());
214+
ParamReferenceReplacerRAII paramReplacer(localDeclMap);
215+
// Create mapping between parameters and copy-params for coroutine
216+
// function.
217+
llvm::ArrayRef<const Stmt *> paramMoves = s.getParamMoves();
218+
assert((paramMoves.size() == 0 || (paramMoves.size() == fnArgs.size())) &&
219+
"ParamMoves and FnArgs should be the same size for coroutine "
220+
"function");
221+
// For zipping the arg map into debug info.
222+
assert(!cir::MissingFeatures::generateDebugInfo());
223+
224+
// Create parameter copies. We do it before creating a promise, since an
225+
// evolution of coroutine TS may allow promise constructor to observe
226+
// parameter copies.
227+
assert(!cir::MissingFeatures::coroOutsideFrameMD());
228+
for (auto *pm : paramMoves) {
229+
if (emitStmt(pm, /*useCurrentScope=*/true).failed())
230+
return mlir::failure();
231+
paramReplacer.addCopy(cast<DeclStmt>(pm));
232+
}
233+
234+
if (emitStmt(s.getPromiseDeclStmt(), /*useCurrentScope=*/true).failed())
235+
return mlir::failure();
236+
// returnValue should be valid as long as the coroutine's return type
237+
// is not void. The assertion could help us to reduce the check later.
238+
assert(returnValue.isValid() == (bool)s.getReturnStmt());
239+
// Now we have the promise, initialize the GRO.
240+
// We need to emit `get_return_object` first. According to:
241+
// [dcl.fct.def.coroutine]p7
242+
// The call to get_return_­object is sequenced before the call to
243+
// initial_suspend and is invoked at most once.
244+
//
245+
// So we couldn't emit return value when we emit return statment,
246+
// otherwise the call to get_return_object wouldn't be in front
247+
// of initial_suspend.
248+
if (returnValue.isValid())
249+
emitAnyExprToMem(s.getReturnValue(), returnValue,
250+
s.getReturnValue()->getType().getQualifiers(),
251+
/*isInit*/ true);
252+
assert(!cir::MissingFeatures::emitBodyAndFallthrough());
253+
}
154254
return mlir::success();
155255
}

clang/lib/CIR/CodeGen/CIRGenFunction.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,10 @@ cir::FuncOp CIRGenFunction::generateCode(clang::GlobalDecl gd, cir::FuncOp fn,
632632

633633
startFunction(gd, retTy, fn, funcType, args, loc, bodyRange.getBegin());
634634

635+
// Save parameters for coroutine function.
636+
if (body && isa_and_nonnull<CoroutineBodyStmt>(body))
637+
llvm::append_range(fnArgs, funcDecl->parameters());
638+
635639
if (isa<CXXDestructorDecl>(funcDecl)) {
636640
emitDestructorBody(args);
637641
} else if (isa<CXXConstructorDecl>(funcDecl)) {

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ class CIRGenFunction : public CIRGenTypeCache {
152152
/// global initializers.
153153
mlir::Operation *curFn = nullptr;
154154

155+
/// Save Parameter Decl for coroutine.
156+
llvm::SmallVector<const ParmVarDecl *> fnArgs;
157+
155158
using DeclMapTy = llvm::DenseMap<const clang::Decl *, Address>;
156159
/// This keeps track of the CIR allocas or globals for local C
157160
/// declarations.

clang/test/CIR/CodeGen/coro-task.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ struct suspend_never {
3636
void await_resume() noexcept {}
3737
};
3838

39+
struct string {
40+
int size() const;
41+
string();
42+
string(char const *s);
43+
};
44+
3945
} // namespace std
4046

4147
namespace folly {
@@ -101,7 +107,10 @@ co_invoke_fn co_invoke;
101107
}} // namespace folly::coro
102108

103109
// CIR-DAG: ![[VoidTask:.*]] = !cir.record<struct "folly::coro::Task<void>" padded {!u8i}>
104-
110+
// CIR-DAG: ![[IntTask:.*]] = !cir.record<struct "folly::coro::Task<int>" padded {!u8i}>
111+
// CIR-DAG: ![[VoidPromisse:.*]] = !cir.record<struct "folly::coro::Task<void>::promise_type" padded {!u8i}>
112+
// CIR-DAG: ![[IntPromisse:.*]] = !cir.record<struct "folly::coro::Task<int>::promise_type" padded {!u8i}>
113+
// CIR-DAG: ![[StdString:.*]] = !cir.record<struct "std::string" padded {!u8i}>
105114
// CIR: module {{.*}} {
106115
// CIR-NEXT: cir.global external @_ZN5folly4coro9co_invokeE = #cir.zero : !rec_folly3A3Acoro3A3Aco_invoke_fn
107116

@@ -119,6 +128,7 @@ VoidTask silly_task() {
119128
// CIR: cir.func coroutine dso_local @_Z10silly_taskv() -> ![[VoidTask]]
120129
// CIR: %[[VoidTaskAddr:.*]] = cir.alloca ![[VoidTask]], {{.*}}, ["__retval"]
121130
// CIR: %[[SavedFrameAddr:.*]] = cir.alloca !cir.ptr<!void>, !cir.ptr<!cir.ptr<!void>>, ["__coro_frame_addr"]
131+
// CIR: %[[VoidPromisseAddr:.*]] = cir.alloca ![[VoidPromisse]], {{.*}}, ["__promise"]
122132

123133
// Get coroutine id with __builtin_coro_id.
124134

@@ -138,3 +148,27 @@ VoidTask silly_task() {
138148
// CIR: }
139149
// CIR: %[[Load0:.*]] = cir.load{{.*}} %[[SavedFrameAddr]] : !cir.ptr<!cir.ptr<!void>>, !cir.ptr<!void>
140150
// CIR: %[[CoroFrameAddr:.*]] = cir.call @__builtin_coro_begin(%[[CoroId]], %[[Load0]])
151+
152+
// Call promise.get_return_object() to retrieve the task object.
153+
154+
// CIR: %[[RetObj:.*]] = cir.call @_ZN5folly4coro4TaskIvE12promise_type17get_return_objectEv(%[[VoidPromisseAddr]]) nothrow : {{.*}} -> ![[VoidTask]]
155+
// CIR: cir.store{{.*}} %[[RetObj]], %[[VoidTaskAddr]] : ![[VoidTask]]
156+
157+
folly::coro::Task<int> byRef(const std::string& s) {
158+
co_return s.size();
159+
}
160+
161+
// CIR: cir.func coroutine dso_local @_Z5byRefRKSt6string(%[[ARG:.*]]: !cir.ptr<![[StdString]]> {{.*}}) -> ![[IntTask]]
162+
// CIR: %[[AllocaParam:.*]] = cir.alloca !cir.ptr<![[StdString]]>, {{.*}}, ["s", init, const]
163+
// CIR: %[[IntTaskAddr:.*]] = cir.alloca ![[IntTask]], {{.*}}, ["__retval"]
164+
// CIR: %[[SavedFrameAddr:.*]] = cir.alloca !cir.ptr<!void>, !cir.ptr<!cir.ptr<!void>>, ["__coro_frame_addr"]
165+
// CIR: %[[AllocaFnUse:.*]] = cir.alloca !cir.ptr<![[StdString]]>, {{.*}}, ["s", init, const]
166+
// CIR: %[[IntPromisseAddr:.*]] = cir.alloca ![[IntPromisse]], {{.*}}, ["__promise"]
167+
// CIR: cir.store %[[ARG]], %[[AllocaParam]] : !cir.ptr<![[StdString]]>, {{.*}}
168+
169+
// Call promise.get_return_object() to retrieve the task object.
170+
171+
// CIR: %[[LOAD:.*]] = cir.load %[[AllocaParam]] : !cir.ptr<!cir.ptr<![[StdString]]>>, !cir.ptr<![[StdString]]>
172+
// CIR: cir.store {{.*}} %[[LOAD]], %[[AllocaFnUse]] : !cir.ptr<![[StdString]]>, !cir.ptr<!cir.ptr<![[StdString]]>>
173+
// CIR: %[[RetObj:.*]] = cir.call @_ZN5folly4coro4TaskIiE12promise_type17get_return_objectEv(%4) nothrow : {{.*}} -> ![[IntTask]]
174+
// CIR: cir.store {{.*}} %[[RetObj]], %[[IntTaskAddr]] : ![[IntTask]]

0 commit comments

Comments
 (0)