Skip to content

Commit baaec5c

Browse files
[CIR] Upstream CIR await op
1 parent 66da12a commit baaec5c

File tree

12 files changed

+389
-5
lines changed

12 files changed

+389
-5
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 98 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -790,8 +790,8 @@ def CIR_ConditionOp : CIR_Op<"condition", [
790790
//===----------------------------------------------------------------------===//
791791

792792
defvar CIR_YieldableScopes = [
793-
"ArrayCtor", "ArrayDtor", "CaseOp", "DoWhileOp", "ForOp", "GlobalOp", "IfOp",
794-
"ScopeOp", "SwitchOp", "TernaryOp", "WhileOp", "TryOp"
793+
"ArrayCtor", "ArrayDtor", "AwaitOp", "CaseOp", "DoWhileOp", "ForOp",
794+
"GlobalOp", "IfOp", "ScopeOp", "SwitchOp", "TernaryOp", "WhileOp", "TryOp"
795795
];
796796

797797
def CIR_YieldOp : CIR_Op<"yield", [
@@ -2707,6 +2707,102 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
27072707
];
27082708
}
27092709

2710+
//===----------------------------------------------------------------------===//
2711+
// AwaitOp
2712+
//===----------------------------------------------------------------------===//
2713+
2714+
def CIR_AwaitKind : CIR_I32EnumAttr<"AwaitKind", "await kind", [
2715+
I32EnumAttrCase<"Init", 0, "init">,
2716+
I32EnumAttrCase<"User", 1, "user">,
2717+
I32EnumAttrCase<"Yield", 2, "yield">,
2718+
I32EnumAttrCase<"Final", 3, "final">
2719+
]>;
2720+
2721+
def CIR_AwaitOp : CIR_Op<"await",[
2722+
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
2723+
RecursivelySpeculatable, NoRegionArguments
2724+
]> {
2725+
let summary = "Wraps C++ co_await implicit logic";
2726+
let description = [{
2727+
The under the hood effect of using C++ `co_await expr` roughly
2728+
translates to:
2729+
2730+
```c++
2731+
// co_await expr;
2732+
2733+
auto &&x = CommonExpr();
2734+
if (!x.await_ready()) {
2735+
...
2736+
x.await_suspend(...);
2737+
...
2738+
}
2739+
x.await_resume();
2740+
```
2741+
2742+
`cir.await` represents this logic by using 3 regions:
2743+
- ready: covers veto power from x.await_ready()
2744+
- suspend: wraps actual x.await_suspend() logic
2745+
- resume: handles x.await_resume()
2746+
2747+
Breaking this up in regions allow individual scrutiny of conditions
2748+
which might lead to folding some of them out. Lowerings coming out
2749+
of CIR, e.g. LLVM, should use the `suspend` region to track more
2750+
lower level codegen (e.g. intrinsic emission for coro.save/coro.suspend).
2751+
2752+
There are also 4 flavors of `cir.await` available:
2753+
- `init`: compiler generated initial suspend via implicit `co_await`.
2754+
- `user`: also known as normal, representing user written co_await's.
2755+
- `yield`: user written `co_yield` expressions.
2756+
- `final`: compiler generated final suspend via implicit `co_await`.
2757+
2758+
From the C++ snippet we get:
2759+
2760+
```mlir
2761+
cir.scope {
2762+
... // auto &&x = CommonExpr();
2763+
cir.await(user, ready : {
2764+
... // x.await_ready()
2765+
}, suspend : {
2766+
... // x.await_suspend()
2767+
}, resume : {
2768+
... // x.await_resume()
2769+
})
2770+
}
2771+
```
2772+
2773+
Note that resulution of the common expression is assumed to happen
2774+
as part of the enclosing await scope.
2775+
}];
2776+
2777+
let arguments = (ins CIR_AwaitKind:$kind);
2778+
let regions = (region SizedRegion<1>:$ready,
2779+
SizedRegion<1>:$suspend,
2780+
SizedRegion<1>:$resume);
2781+
let assemblyFormat = [{
2782+
`(` $kind `,`
2783+
`ready` `:` $ready `,`
2784+
`suspend` `:` $suspend `,`
2785+
`resume` `:` $resume `,`
2786+
`)`
2787+
attr-dict
2788+
}];
2789+
2790+
let skipDefaultBuilders = 1;
2791+
let builders = [
2792+
OpBuilder<(ins
2793+
"cir::AwaitKind":$kind,
2794+
CArg<"BuilderCallbackRef",
2795+
"nullptr">:$readyBuilder,
2796+
CArg<"BuilderCallbackRef",
2797+
"nullptr">:$suspendBuilder,
2798+
CArg<"BuilderCallbackRef",
2799+
"nullptr">:$resumeBuilder
2800+
)>
2801+
];
2802+
2803+
let hasVerifier = 1;
2804+
}
2805+
27102806
//===----------------------------------------------------------------------===//
27112807
// CopyOp
27122808
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ using namespace clang;
2222
using namespace clang::CIRGen;
2323

2424
struct clang::CIRGen::CGCoroData {
25+
// What is the current await expression kind and how many
26+
// await/yield expressions were encountered so far.
27+
// These are used to generate pretty labels for await expressions in LLVM IR.
28+
cir::AwaitKind currentAwaitKind = cir::AwaitKind::Init;
2529
// Stores the __builtin_coro_id emitted in the function so that we can supply
2630
// it as the first argument to other builtins.
2731
cir::CallOp coroId = nullptr;
@@ -249,7 +253,114 @@ CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) {
249253
emitAnyExprToMem(s.getReturnValue(), returnValue,
250254
s.getReturnValue()->getType().getQualifiers(),
251255
/*isInit*/ true);
256+
257+
assert(!cir::MissingFeatures::ehCleanupScope());
258+
// FIXME(cir): EHStack.pushCleanup<CallCoroEnd>(EHCleanup);
259+
curCoro.data->currentAwaitKind = cir::AwaitKind::Init;
260+
if (emitStmt(s.getInitSuspendStmt(), /*useCurrentScope=*/true).failed())
261+
return mlir::failure();
252262
assert(!cir::MissingFeatures::emitBodyAndFallthrough());
253263
}
254264
return mlir::success();
255265
}
266+
// Given a suspend expression which roughly looks like:
267+
//
268+
// auto && x = CommonExpr();
269+
// if (!x.await_ready()) {
270+
// x.await_suspend(...); (*)
271+
// }
272+
// x.await_resume();
273+
//
274+
// where the result of the entire expression is the result of x.await_resume()
275+
//
276+
// (*) If x.await_suspend return type is bool, it allows to veto a suspend:
277+
// if (x.await_suspend(...))
278+
// llvm_coro_suspend();
279+
//
280+
// This is more higher level than LLVM codegen, for that one see llvm's
281+
// docs/Coroutines.rst for more details.
282+
namespace {
283+
struct LValueOrRValue {
284+
LValue lv;
285+
RValue rv;
286+
};
287+
} // namespace
288+
289+
static LValueOrRValue
290+
emitSuspendExpression(CIRGenFunction &cgf, CGCoroData &coro,
291+
CoroutineSuspendExpr const &s, cir::AwaitKind kind,
292+
AggValueSlot aggSlot, bool ignoreResult,
293+
mlir::Block *scopeParentBlock,
294+
mlir::Value &tmpResumeRValAddr, bool forLValue) {
295+
mlir::LogicalResult awaitBuild = mlir::success();
296+
LValueOrRValue awaitRes;
297+
298+
CIRGenFunction::OpaqueValueMapping binder =
299+
CIRGenFunction::OpaqueValueMapping(cgf, s.getOpaqueValue());
300+
CIRGenBuilderTy &builder = cgf.getBuilder();
301+
[[maybe_unused]] cir::AwaitOp awaitOp = cir::AwaitOp::create(
302+
builder, cgf.getLoc(s.getSourceRange()), kind,
303+
/*readyBuilder=*/
304+
[&](mlir::OpBuilder &b, mlir::Location loc) {
305+
builder.createCondition(
306+
cgf.createDummyValue(loc, cgf.getContext().BoolTy));
307+
},
308+
/*suspendBuilder=*/
309+
[&](mlir::OpBuilder &b, mlir::Location loc) {
310+
cir::YieldOp::create(builder, loc);
311+
},
312+
/*resumeBuilder=*/
313+
[&](mlir::OpBuilder &b, mlir::Location loc) {
314+
cir::YieldOp::create(builder, loc);
315+
});
316+
317+
assert(awaitBuild.succeeded() && "Should know how to codegen");
318+
return awaitRes;
319+
}
320+
321+
static RValue emitSuspendExpr(CIRGenFunction &cgf,
322+
const CoroutineSuspendExpr &e,
323+
cir::AwaitKind kind, AggValueSlot aggSlot,
324+
bool ignoreResult) {
325+
RValue rval;
326+
mlir::Location scopeLoc = cgf.getLoc(e.getSourceRange());
327+
328+
// Since we model suspend / resume as an inner region, we must store
329+
// resume scalar results in a tmp alloca, and load it after we build the
330+
// suspend expression. An alternative way to do this would be to make
331+
// every region return a value when promise.return_value() is used, but
332+
// it's a bit awkward given that resume is the only region that actually
333+
// returns a value.
334+
mlir::Block *currEntryBlock = cgf.curLexScope->getEntryBlock();
335+
[[maybe_unused]] mlir::Value tmpResumeRValAddr;
336+
337+
// No need to explicitly wrap this into a scope since the AST already uses a
338+
// ExprWithCleanups, which will wrap this into a cir.scope anyways.
339+
rval = emitSuspendExpression(cgf, *cgf.curCoro.data, e, kind, aggSlot,
340+
ignoreResult, currEntryBlock, tmpResumeRValAddr,
341+
/*forLValue*/ false)
342+
.rv;
343+
344+
if (ignoreResult || rval.isIgnored())
345+
return rval;
346+
347+
if (rval.isScalar()) {
348+
rval = RValue::get(cir::LoadOp::create(cgf.getBuilder(), scopeLoc,
349+
rval.getValue().getType(),
350+
tmpResumeRValAddr));
351+
} else if (rval.isAggregate()) {
352+
// This is probably already handled via AggSlot, remove this assertion
353+
// once we have a testcase and prove all pieces work.
354+
cgf.cgm.errorNYI("emitSuspendExpr Aggregate");
355+
} else { // complex
356+
cgf.cgm.errorNYI("emitSuspendExpr Complex");
357+
}
358+
return rval;
359+
}
360+
361+
RValue CIRGenFunction::emitCoawaitExpr(const CoawaitExpr &e,
362+
AggValueSlot aggSlot,
363+
bool ignoreResult) {
364+
return emitSuspendExpr(*this, e, curCoro.data->currentAwaitKind, aggSlot,
365+
ignoreResult);
366+
}

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
154154
return cgf.emitLoadOfLValue(lv, e->getExprLoc()).getValue();
155155
}
156156

157+
mlir::Value VisitCoawaitExpr(CoawaitExpr *s) {
158+
return cgf.emitCoawaitExpr(*s).getValue();
159+
}
160+
157161
mlir::Value emitLoadOfLValue(LValue lv, SourceLocation loc) {
158162
return cgf.emitLoadOfLValue(lv, loc).getValue();
159163
}

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,6 +1557,9 @@ class CIRGenFunction : public CIRGenTypeCache {
15571557
void emitForwardingCallToLambda(const CXXMethodDecl *lambdaCallOperator,
15581558
CallArgList &callArgs);
15591559

1560+
RValue emitCoawaitExpr(const CoawaitExpr &e,
1561+
AggValueSlot aggSlot = AggValueSlot::ignored(),
1562+
bool ignoreResult = false);
15601563
/// Emit the computation of the specified expression of complex type,
15611564
/// returning the result.
15621565
mlir::Value emitComplexExpr(const Expr *e);

clang/lib/CIR/CodeGen/CIRGenValue.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class RValue {
4949
bool isScalar() const { return flavor == Scalar; }
5050
bool isComplex() const { return flavor == Complex; }
5151
bool isAggregate() const { return flavor == Aggregate; }
52+
bool isIgnored() const { return isScalar() && !getValue(); }
5253

5354
bool isVolatileQualified() const { return isVolatile; }
5455

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 77 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,10 @@ void cir::ConditionOp::getSuccessorRegions(
289289
regions.emplace_back(getOperation(), loopOp->getResults());
290290
}
291291

292-
assert(!cir::MissingFeatures::awaitOp());
292+
// Parent is an await: condition may branch to resume or suspend regions.
293+
auto await = cast<AwaitOp>(getOperation()->getParentOp());
294+
regions.emplace_back(&await.getResume(), await.getResume().getArguments());
295+
regions.emplace_back(&await.getSuspend(), await.getSuspend().getArguments());
293296
}
294297

295298
MutableOperandRange
@@ -299,8 +302,7 @@ cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
299302
}
300303

301304
LogicalResult cir::ConditionOp::verify() {
302-
assert(!cir::MissingFeatures::awaitOp());
303-
if (!isa<LoopOpInterface>(getOperation()->getParentOp()))
305+
if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
304306
return emitOpError("condition must be within a conditional region");
305307
return success();
306308
}
@@ -1900,6 +1902,19 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
19001902

19011903
mlir::LogicalResult cir::FuncOp::verify() {
19021904

1905+
if (!isDeclaration() && getCoroutine()) {
1906+
bool foundAwait = false;
1907+
this->walk([&](Operation *op) {
1908+
if (auto await = dyn_cast<AwaitOp>(op)) {
1909+
foundAwait = true;
1910+
return;
1911+
}
1912+
});
1913+
if (!foundAwait)
1914+
return emitOpError()
1915+
<< "coroutine body must use at least one cir.await op";
1916+
}
1917+
19031918
llvm::SmallSet<llvm::StringRef, 16> labels;
19041919
llvm::SmallSet<llvm::StringRef, 16> gotos;
19051920

@@ -2116,6 +2131,65 @@ OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {
21162131

21172132
return {};
21182133
}
2134+
//===----------------------------------------------------------------------===//
2135+
// AwaitOp
2136+
//===----------------------------------------------------------------------===//
2137+
2138+
void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
2139+
cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
2140+
BuilderCallbackRef suspendBuilder,
2141+
BuilderCallbackRef resumeBuilder) {
2142+
result.addAttribute(getKindAttrName(result.name),
2143+
cir::AwaitKindAttr::get(builder.getContext(), kind));
2144+
{
2145+
OpBuilder::InsertionGuard guard(builder);
2146+
Region *readyRegion = result.addRegion();
2147+
builder.createBlock(readyRegion);
2148+
readyBuilder(builder, result.location);
2149+
}
2150+
2151+
{
2152+
OpBuilder::InsertionGuard guard(builder);
2153+
Region *suspendRegion = result.addRegion();
2154+
builder.createBlock(suspendRegion);
2155+
suspendBuilder(builder, result.location);
2156+
}
2157+
2158+
{
2159+
OpBuilder::InsertionGuard guard(builder);
2160+
Region *resumeRegion = result.addRegion();
2161+
builder.createBlock(resumeRegion);
2162+
resumeBuilder(builder, result.location);
2163+
}
2164+
}
2165+
2166+
/// Given the region at `index`, or the parent operation if `index` is None,
2167+
/// return the successor regions. These are the regions that may be selected
2168+
/// during the flow of control. `operands` is a set of optional attributes
2169+
/// that correspond to a constant value for each operand, or null if that
2170+
/// operand is not a constant.
2171+
void cir::AwaitOp::getSuccessorRegions(
2172+
mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
2173+
// If any index all the underlying regions branch back to the parent
2174+
// operation.
2175+
if (!point.isParent()) {
2176+
regions.push_back(
2177+
RegionSuccessor(getOperation(), getOperation()->getResults()));
2178+
return;
2179+
}
2180+
2181+
// FIXME: we want to look at cond region for getting more accurate results
2182+
// if the other regions will get a chance to execute.
2183+
regions.push_back(RegionSuccessor(&this->getReady()));
2184+
regions.push_back(RegionSuccessor(&this->getSuspend()));
2185+
regions.push_back(RegionSuccessor(&this->getResume()));
2186+
}
2187+
2188+
LogicalResult cir::AwaitOp::verify() {
2189+
if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
2190+
return emitOpError("ready region must end with cir.condition");
2191+
return success();
2192+
}
21192193

21202194
//===----------------------------------------------------------------------===//
21212195
// CopyOp Definitions

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3778,6 +3778,12 @@ mlir::LogicalResult CIRToLLVMVAArgOpLowering::matchAndRewrite(
37783778
return mlir::success();
37793779
}
37803780

3781+
mlir::LogicalResult CIRToLLVMAwaitOpLowering::matchAndRewrite(
3782+
cir::AwaitOp op, OpAdaptor adaptor,
3783+
mlir::ConversionPatternRewriter &rewriter) const {
3784+
return mlir::failure();
3785+
}
3786+
37813787
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
37823788
return std::make_unique<ConvertCIRToLLVMPass>();
37833789
}

clang/lib/CodeGen/CGValue.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class RValue {
6464
bool isScalar() const { return Flavor == Scalar; }
6565
bool isComplex() const { return Flavor == Complex; }
6666
bool isAggregate() const { return Flavor == Aggregate; }
67+
bool isIgnored() const { return isScalar() && !getScalarVal(); }
6768

6869
bool isVolatileQualified() const { return IsVolatile; }
6970

0 commit comments

Comments
 (0)