Skip to content

Commit 5c43385

Browse files
[CIR] Upstream CIR await op (#168133)
This PR upstreams `cir.await` and adds initial codegen for emitting a skeleton of the ready, suspend, and resume branches. Codegen for these branches is left for a future PR. It also adds a test for the invalid case where a `cir.func` is marked as a coroutine but does not contain a `cir.await` op in its body.
1 parent eea6215 commit 5c43385

File tree

12 files changed

+383
-5
lines changed

12 files changed

+383
-5
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -802,8 +802,8 @@ def CIR_ConditionOp : CIR_Op<"condition", [
802802
//===----------------------------------------------------------------------===//
803803

804804
defvar CIR_YieldableScopes = [
805-
"ArrayCtor", "ArrayDtor", "CaseOp", "DoWhileOp", "ForOp", "GlobalOp", "IfOp",
806-
"ScopeOp", "SwitchOp", "TernaryOp", "WhileOp", "TryOp"
805+
"ArrayCtor", "ArrayDtor", "AwaitOp", "CaseOp", "DoWhileOp", "ForOp",
806+
"GlobalOp", "IfOp", "ScopeOp", "SwitchOp", "TernaryOp", "WhileOp", "TryOp"
807807
];
808808

809809
def CIR_YieldOp : CIR_Op<"yield", [
@@ -2752,6 +2752,100 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
27522752
];
27532753
}
27542754

2755+
//===----------------------------------------------------------------------===//
2756+
// AwaitOp
2757+
//===----------------------------------------------------------------------===//
2758+
2759+
def CIR_AwaitKind : CIR_I32EnumAttr<"AwaitKind", "await kind", [
2760+
I32EnumAttrCase<"Init", 0, "init">,
2761+
I32EnumAttrCase<"User", 1, "user">,
2762+
I32EnumAttrCase<"Yield", 2, "yield">,
2763+
I32EnumAttrCase<"Final", 3, "final">
2764+
]>;
2765+
2766+
def CIR_AwaitOp : CIR_Op<"await",[
2767+
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
2768+
RecursivelySpeculatable, NoRegionArguments
2769+
]> {
2770+
let summary = "Wraps C++ co_await implicit logic";
2771+
let description = [{
2772+
The under-the-hood effect of using C++ `co_await expr` roughly
2773+
translates to:
2774+
2775+
```c++
2776+
// co_await expr;
2777+
2778+
auto &&x = CommonExpr();
2779+
if (!x.await_ready()) {
2780+
...
2781+
x.await_suspend(...);
2782+
...
2783+
}
2784+
x.await_resume();
2785+
```
2786+
2787+
`cir.await` represents this logic by using 3 regions:
2788+
- ready: covers veto power from x.await_ready()
2789+
- suspend: wraps actual x.await_suspend() logic
2790+
- resume: handles x.await_resume()
2791+
2792+
Breaking this up in regions allows individual scrutiny of conditions
2793+
which might lead to folding some of them out. Lowerings coming out
2794+
of CIR, e.g. LLVM, should use the `suspend` region to track more
2795+
lower level codegen (e.g. intrinsic emission for coro.save/coro.suspend).
2796+
2797+
There are also 4 flavors of `cir.await` available:
2798+
- `init`: compiler generated initial suspend via implicit `co_await`.
2799+
- `user`: also known as normal, representing a user written `co_await`.
2800+
- `yield`: user written `co_yield` expressions.
2801+
- `final`: compiler generated final suspend via implicit `co_await`.
2802+
2803+
```mlir
2804+
cir.scope {
2805+
... // auto &&x = CommonExpr();
2806+
cir.await(user, ready : {
2807+
... // x.await_ready()
2808+
}, suspend : {
2809+
... // x.await_suspend()
2810+
}, resume : {
2811+
... // x.await_resume()
2812+
},)
2813+
}
2814+
```
2815+
2816+
Note that resolution of the common expression is assumed to happen
2817+
as part of the enclosing await scope.
2818+
}];
2819+
2820+
let arguments = (ins CIR_AwaitKind:$kind);
2821+
let regions = (region SizedRegion<1>:$ready,
2822+
SizedRegion<1>:$suspend,
2823+
SizedRegion<1>:$resume);
2824+
let assemblyFormat = [{
2825+
`(` $kind `,`
2826+
`ready` `:` $ready `,`
2827+
`suspend` `:` $suspend `,`
2828+
`resume` `:` $resume `,`
2829+
`)`
2830+
attr-dict
2831+
}];
2832+
2833+
let skipDefaultBuilders = 1;
2834+
let builders = [
2835+
OpBuilder<(ins
2836+
"cir::AwaitKind":$kind,
2837+
CArg<"BuilderCallbackRef",
2838+
"nullptr">:$readyBuilder,
2839+
CArg<"BuilderCallbackRef",
2840+
"nullptr">:$suspendBuilder,
2841+
CArg<"BuilderCallbackRef",
2842+
"nullptr">:$resumeBuilder
2843+
)>
2844+
];
2845+
2846+
let hasVerifier = 1;
2847+
}
2848+
27552849
//===----------------------------------------------------------------------===//
27562850
// CopyOp
27572851
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ using namespace clang;
2222
using namespace clang::CIRGen;
2323

2424
struct clang::CIRGen::CGCoroData {
25+
// What is the current await expression kind and how many
26+
// await/yield expressions were encountered so far.
27+
// These are used to generate pretty labels for await expressions in LLVM IR.
28+
cir::AwaitKind currentAwaitKind = cir::AwaitKind::Init;
2529
// Stores the __builtin_coro_id emitted in the function so that we can supply
2630
// it as the first argument to other builtins.
2731
cir::CallOp coroId = nullptr;
@@ -249,7 +253,114 @@ CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) {
249253
emitAnyExprToMem(s.getReturnValue(), returnValue,
250254
s.getReturnValue()->getType().getQualifiers(),
251255
/*isInit*/ true);
256+
257+
assert(!cir::MissingFeatures::ehCleanupScope());
258+
// FIXME(cir): EHStack.pushCleanup<CallCoroEnd>(EHCleanup);
259+
curCoro.data->currentAwaitKind = cir::AwaitKind::Init;
260+
if (emitStmt(s.getInitSuspendStmt(), /*useCurrentScope=*/true).failed())
261+
return mlir::failure();
252262
assert(!cir::MissingFeatures::emitBodyAndFallthrough());
253263
}
254264
return mlir::success();
255265
}
266+
// Given a suspend expression which roughly looks like:
267+
//
268+
// auto && x = CommonExpr();
269+
// if (!x.await_ready()) {
270+
// x.await_suspend(...); (*)
271+
// }
272+
// x.await_resume();
273+
//
274+
// where the result of the entire expression is the result of x.await_resume()
275+
//
276+
// (*) If the return type of x.await_suspend is bool, it can veto a suspend:
277+
// if (x.await_suspend(...))
278+
// llvm_coro_suspend();
279+
//
280+
// This is higher level than LLVM codegen; for that one, see LLVM's
281+
// docs/Coroutines.rst for more details.
282+
namespace {
283+
struct LValueOrRValue {
284+
LValue lv;
285+
RValue rv;
286+
};
287+
} // namespace
288+
289+
static LValueOrRValue
290+
emitSuspendExpression(CIRGenFunction &cgf, CGCoroData &coro,
291+
CoroutineSuspendExpr const &s, cir::AwaitKind kind,
292+
AggValueSlot aggSlot, bool ignoreResult,
293+
mlir::Block *scopeParentBlock,
294+
mlir::Value &tmpResumeRValAddr, bool forLValue) {
295+
mlir::LogicalResult awaitBuild = mlir::success();
296+
LValueOrRValue awaitRes;
297+
298+
CIRGenFunction::OpaqueValueMapping binder =
299+
CIRGenFunction::OpaqueValueMapping(cgf, s.getOpaqueValue());
300+
CIRGenBuilderTy &builder = cgf.getBuilder();
301+
[[maybe_unused]] cir::AwaitOp awaitOp = cir::AwaitOp::create(
302+
builder, cgf.getLoc(s.getSourceRange()), kind,
303+
/*readyBuilder=*/
304+
[&](mlir::OpBuilder &b, mlir::Location loc) {
305+
builder.createCondition(
306+
cgf.createDummyValue(loc, cgf.getContext().BoolTy));
307+
},
308+
/*suspendBuilder=*/
309+
[&](mlir::OpBuilder &b, mlir::Location loc) {
310+
cir::YieldOp::create(builder, loc);
311+
},
312+
/*resumeBuilder=*/
313+
[&](mlir::OpBuilder &b, mlir::Location loc) {
314+
cir::YieldOp::create(builder, loc);
315+
});
316+
317+
assert(awaitBuild.succeeded() && "Should know how to codegen");
318+
return awaitRes;
319+
}
320+
321+
static RValue emitSuspendExpr(CIRGenFunction &cgf,
322+
const CoroutineSuspendExpr &e,
323+
cir::AwaitKind kind, AggValueSlot aggSlot,
324+
bool ignoreResult) {
325+
RValue rval;
326+
mlir::Location scopeLoc = cgf.getLoc(e.getSourceRange());
327+
328+
// Since we model suspend / resume as an inner region, we must store
329+
// resume scalar results in a tmp alloca, and load it after we build the
330+
// suspend expression. An alternative way to do this would be to make
331+
// every region return a value when promise.return_value() is used, but
332+
// it's a bit awkward given that resume is the only region that actually
333+
// returns a value.
334+
mlir::Block *currEntryBlock = cgf.curLexScope->getEntryBlock();
335+
[[maybe_unused]] mlir::Value tmpResumeRValAddr;
336+
337+
// No need to explicitly wrap this into a scope since the AST already uses an
338+
// ExprWithCleanups, which will wrap this into a cir.scope anyways.
339+
rval = emitSuspendExpression(cgf, *cgf.curCoro.data, e, kind, aggSlot,
340+
ignoreResult, currEntryBlock, tmpResumeRValAddr,
341+
/*forLValue*/ false)
342+
.rv;
343+
344+
if (ignoreResult || rval.isIgnored())
345+
return rval;
346+
347+
if (rval.isScalar()) {
348+
rval = RValue::get(cir::LoadOp::create(cgf.getBuilder(), scopeLoc,
349+
rval.getValue().getType(),
350+
tmpResumeRValAddr));
351+
} else if (rval.isAggregate()) {
352+
// This is probably already handled via AggSlot, remove this assertion
353+
// once we have a testcase and prove all pieces work.
354+
cgf.cgm.errorNYI("emitSuspendExpr Aggregate");
355+
} else { // complex
356+
cgf.cgm.errorNYI("emitSuspendExpr Complex");
357+
}
358+
return rval;
359+
}
360+
361+
RValue CIRGenFunction::emitCoawaitExpr(const CoawaitExpr &e,
362+
AggValueSlot aggSlot,
363+
bool ignoreResult) {
364+
return emitSuspendExpr(*this, e, curCoro.data->currentAwaitKind, aggSlot,
365+
ignoreResult);
366+
}

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,10 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
156156
return cgf.emitLoadOfLValue(lv, e->getExprLoc()).getValue();
157157
}
158158

159+
mlir::Value VisitCoawaitExpr(CoawaitExpr *s) {
160+
return cgf.emitCoawaitExpr(*s).getValue();
161+
}
162+
159163
mlir::Value emitLoadOfLValue(LValue lv, SourceLocation loc) {
160164
return cgf.emitLoadOfLValue(lv, loc).getValue();
161165
}

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,6 +1576,9 @@ class CIRGenFunction : public CIRGenTypeCache {
15761576
void emitForwardingCallToLambda(const CXXMethodDecl *lambdaCallOperator,
15771577
CallArgList &callArgs);
15781578

1579+
RValue emitCoawaitExpr(const CoawaitExpr &e,
1580+
AggValueSlot aggSlot = AggValueSlot::ignored(),
1581+
bool ignoreResult = false);
15791582
/// Emit the computation of the specified expression of complex type,
15801583
/// returning the result.
15811584
mlir::Value emitComplexExpr(const Expr *e);

clang/lib/CIR/CodeGen/CIRGenValue.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class RValue {
4949
bool isScalar() const { return flavor == Scalar; }
5050
bool isComplex() const { return flavor == Complex; }
5151
bool isAggregate() const { return flavor == Aggregate; }
52+
bool isIgnored() const { return isScalar() && !getValue(); }
5253

5354
bool isVolatileQualified() const { return isVolatile; }
5455

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,10 @@ void cir::ConditionOp::getSuccessorRegions(
289289
regions.emplace_back(getOperation(), loopOp->getResults());
290290
}
291291

292-
assert(!cir::MissingFeatures::awaitOp());
292+
// Parent is an await: condition may branch to resume or suspend regions.
293+
auto await = cast<AwaitOp>(getOperation()->getParentOp());
294+
regions.emplace_back(&await.getResume(), await.getResume().getArguments());
295+
regions.emplace_back(&await.getSuspend(), await.getSuspend().getArguments());
293296
}
294297

295298
MutableOperandRange
@@ -299,8 +302,7 @@ cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
299302
}
300303

301304
LogicalResult cir::ConditionOp::verify() {
302-
assert(!cir::MissingFeatures::awaitOp());
303-
if (!isa<LoopOpInterface>(getOperation()->getParentOp()))
305+
if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
304306
return emitOpError("condition must be within a conditional region");
305307
return success();
306308
}
@@ -1910,6 +1912,19 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
19101912

19111913
mlir::LogicalResult cir::FuncOp::verify() {
19121914

1915+
if (!isDeclaration() && getCoroutine()) {
1916+
bool foundAwait = false;
1917+
this->walk([&](Operation *op) {
1918+
if (auto await = dyn_cast<AwaitOp>(op)) {
1919+
foundAwait = true;
1920+
return;
1921+
}
1922+
});
1923+
if (!foundAwait)
1924+
return emitOpError()
1925+
<< "coroutine body must use at least one cir.await op";
1926+
}
1927+
19131928
llvm::SmallSet<llvm::StringRef, 16> labels;
19141929
llvm::SmallSet<llvm::StringRef, 16> gotos;
19151930
llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
@@ -2149,6 +2164,61 @@ OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
21492164

21502165
return {};
21512166
}
2167+
//===----------------------------------------------------------------------===//
2168+
// AwaitOp
2169+
//===----------------------------------------------------------------------===//
2170+
2171+
void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
2172+
cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
2173+
BuilderCallbackRef suspendBuilder,
2174+
BuilderCallbackRef resumeBuilder) {
2175+
result.addAttribute(getKindAttrName(result.name),
2176+
cir::AwaitKindAttr::get(builder.getContext(), kind));
2177+
{
2178+
OpBuilder::InsertionGuard guard(builder);
2179+
Region *readyRegion = result.addRegion();
2180+
builder.createBlock(readyRegion);
2181+
readyBuilder(builder, result.location);
2182+
}
2183+
2184+
{
2185+
OpBuilder::InsertionGuard guard(builder);
2186+
Region *suspendRegion = result.addRegion();
2187+
builder.createBlock(suspendRegion);
2188+
suspendBuilder(builder, result.location);
2189+
}
2190+
2191+
{
2192+
OpBuilder::InsertionGuard guard(builder);
2193+
Region *resumeRegion = result.addRegion();
2194+
builder.createBlock(resumeRegion);
2195+
resumeBuilder(builder, result.location);
2196+
}
2197+
}
2198+
2199+
void cir::AwaitOp::getSuccessorRegions(
2200+
mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2201+
// For any region index, all the underlying regions branch back to the parent
2202+
// operation.
2203+
if (!point.isParent()) {
2204+
regions.push_back(
2205+
RegionSuccessor(getOperation(), getOperation()->getResults()));
2206+
return;
2207+
}
2208+
2209+
// TODO: retrieve information from the promise and only push the
2210+
// necessary ones. Example: `std::suspend_never` on initial or final
2211+
// await's might allow suspend region to be skipped.
2212+
regions.push_back(RegionSuccessor(&this->getReady()));
2213+
regions.push_back(RegionSuccessor(&this->getSuspend()));
2214+
regions.push_back(RegionSuccessor(&this->getResume()));
2215+
}
2216+
2217+
LogicalResult cir::AwaitOp::verify() {
2218+
if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
2219+
return emitOpError("ready region must end with cir.condition");
2220+
return success();
2221+
}
21522222

21532223
//===----------------------------------------------------------------------===//
21542224
// CopyOp Definitions

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3931,6 +3931,12 @@ mlir::LogicalResult CIRToLLVMBlockAddressOpLowering::matchAndRewrite(
39313931
return mlir::failure();
39323932
}
39333933

3934+
mlir::LogicalResult CIRToLLVMAwaitOpLowering::matchAndRewrite(
3935+
cir::AwaitOp op, OpAdaptor adaptor,
3936+
mlir::ConversionPatternRewriter &rewriter) const {
3937+
return mlir::failure();
3938+
}
3939+
39343940
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
39353941
return std::make_unique<ConvertCIRToLLVMPass>();
39363942
}

clang/lib/CodeGen/CGValue.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class RValue {
6464
bool isScalar() const { return Flavor == Scalar; }
6565
bool isComplex() const { return Flavor == Complex; }
6666
bool isAggregate() const { return Flavor == Aggregate; }
67+
bool isIgnored() const { return isScalar() && !getScalarVal(); }
6768

6869
bool isVolatileQualified() const { return IsVolatile; }
6970

0 commit comments

Comments
 (0)