Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 96 additions & 2 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -802,8 +802,8 @@ def CIR_ConditionOp : CIR_Op<"condition", [
//===----------------------------------------------------------------------===//

defvar CIR_YieldableScopes = [
"ArrayCtor", "ArrayDtor", "CaseOp", "DoWhileOp", "ForOp", "GlobalOp", "IfOp",
"ScopeOp", "SwitchOp", "TernaryOp", "WhileOp", "TryOp"
"ArrayCtor", "ArrayDtor", "AwaitOp", "CaseOp", "DoWhileOp", "ForOp",
"GlobalOp", "IfOp", "ScopeOp", "SwitchOp", "TernaryOp", "WhileOp", "TryOp"
];

def CIR_YieldOp : CIR_Op<"yield", [
Expand Down Expand Up @@ -2752,6 +2752,100 @@ def CIR_CallOp : CIR_CallOpBase<"call", [NoRegionArguments]> {
];
}

//===----------------------------------------------------------------------===//
// AwaitOp
//===----------------------------------------------------------------------===//

def CIR_AwaitKind : CIR_I32EnumAttr<"AwaitKind", "await kind", [
I32EnumAttrCase<"Init", 0, "init">,
I32EnumAttrCase<"User", 1, "user">,
I32EnumAttrCase<"Yield", 2, "yield">,
I32EnumAttrCase<"Final", 3, "final">
]>;

def CIR_AwaitOp : CIR_Op<"await",[
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
RecursivelySpeculatable, NoRegionArguments
]> {
let summary = "Wraps C++ co_await implicit logic";
let description = [{
The under the hood effect of using C++ `co_await expr` roughly
translates to:

```c++
// co_await expr;

auto &&x = CommonExpr();
if (!x.await_ready()) {
...
x.await_suspend(...);
...
}
x.await_resume();
```

`cir.await` represents this logic by using 3 regions:
- ready: covers veto power from x.await_ready()
- suspend: wraps actual x.await_suspend() logic
- resume: handles x.await_resume()

Breaking this up in regions allows individual scrutiny of conditions
which might lead to folding some of them out. Lowerings coming out
of CIR, e.g. LLVM, should use the `suspend` region to track more
lower level codegen (e.g. intrinsic emission for coro.save/coro.suspend).

There are also 4 flavors of `cir.await` available:
- `init`: compiler generated initial suspend via implicit `co_await`.
- `user`: also known as normal, representing a user written `co_await`.
- `yield`: user written `co_yield` expressions.
- `final`: compiler generated final suspend via implicit `co_await`.

```mlir
cir.scope {
... // auto &&x = CommonExpr();
cir.await(user, ready : {
... // x.await_ready()
}, suspend : {
... // x.await_suspend()
}, resume : {
... // x.await_resume()
})
}
```

Note that resulution of the common expression is assumed to happen
as part of the enclosing await scope.
}];

let arguments = (ins CIR_AwaitKind:$kind);
let regions = (region SizedRegion<1>:$ready,
SizedRegion<1>:$suspend,
SizedRegion<1>:$resume);
let assemblyFormat = [{
`(` $kind `,`
`ready` `:` $ready `,`
`suspend` `:` $suspend `,`
`resume` `:` $resume `,`
`)`
attr-dict
}];

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins
"cir::AwaitKind":$kind,
CArg<"BuilderCallbackRef",
"nullptr">:$readyBuilder,
CArg<"BuilderCallbackRef",
"nullptr">:$suspendBuilder,
CArg<"BuilderCallbackRef",
"nullptr">:$resumeBuilder
)>
];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//
Expand Down
111 changes: 111 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenCoroutine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ using namespace clang;
using namespace clang::CIRGen;

struct clang::CIRGen::CGCoroData {
// What is the current await expression kind and how many
// await/yield expressions were encountered so far.
// These are used to generate pretty labels for await expressions in LLVM IR.
cir::AwaitKind currentAwaitKind = cir::AwaitKind::Init;
// Stores the __builtin_coro_id emitted in the function so that we can supply
// it as the first argument to other builtins.
cir::CallOp coroId = nullptr;
Expand Down Expand Up @@ -249,7 +253,114 @@ CIRGenFunction::emitCoroutineBody(const CoroutineBodyStmt &s) {
emitAnyExprToMem(s.getReturnValue(), returnValue,
s.getReturnValue()->getType().getQualifiers(),
/*isInit*/ true);

assert(!cir::MissingFeatures::ehCleanupScope());
// FIXME(cir): EHStack.pushCleanup<CallCoroEnd>(EHCleanup);
curCoro.data->currentAwaitKind = cir::AwaitKind::Init;
if (emitStmt(s.getInitSuspendStmt(), /*useCurrentScope=*/true).failed())
return mlir::failure();
assert(!cir::MissingFeatures::emitBodyAndFallthrough());
}
return mlir::success();
}
// Given a suspend expression which roughly looks like:
//
// auto && x = CommonExpr();
// if (!x.await_ready()) {
// x.await_suspend(...); (*)
// }
// x.await_resume();
//
// where the result of the entire expression is the result of x.await_resume()
//
// (*) If x.await_suspend return type is bool, it allows to veto a suspend:
// if (x.await_suspend(...))
// llvm_coro_suspend();
//
// This is more higher level than LLVM codegen, for that one see llvm's
// docs/Coroutines.rst for more details.
namespace {
struct LValueOrRValue {
LValue lv;
RValue rv;
};
} // namespace

static LValueOrRValue
emitSuspendExpression(CIRGenFunction &cgf, CGCoroData &coro,
CoroutineSuspendExpr const &s, cir::AwaitKind kind,
AggValueSlot aggSlot, bool ignoreResult,
mlir::Block *scopeParentBlock,
mlir::Value &tmpResumeRValAddr, bool forLValue) {
mlir::LogicalResult awaitBuild = mlir::success();
LValueOrRValue awaitRes;

CIRGenFunction::OpaqueValueMapping binder =
CIRGenFunction::OpaqueValueMapping(cgf, s.getOpaqueValue());
CIRGenBuilderTy &builder = cgf.getBuilder();
[[maybe_unused]] cir::AwaitOp awaitOp = cir::AwaitOp::create(
builder, cgf.getLoc(s.getSourceRange()), kind,
/*readyBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
builder.createCondition(
cgf.createDummyValue(loc, cgf.getContext().BoolTy));
},
/*suspendBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
cir::YieldOp::create(builder, loc);
},
/*resumeBuilder=*/
[&](mlir::OpBuilder &b, mlir::Location loc) {
cir::YieldOp::create(builder, loc);
});

assert(awaitBuild.succeeded() && "Should know how to codegen");
return awaitRes;
}

static RValue emitSuspendExpr(CIRGenFunction &cgf,
const CoroutineSuspendExpr &e,
cir::AwaitKind kind, AggValueSlot aggSlot,
bool ignoreResult) {
RValue rval;
mlir::Location scopeLoc = cgf.getLoc(e.getSourceRange());

// Since we model suspend / resume as an inner region, we must store
// resume scalar results in a tmp alloca, and load it after we build the
// suspend expression. An alternative way to do this would be to make
// every region return a value when promise.return_value() is used, but
// it's a bit awkward given that resume is the only region that actually
// returns a value.
mlir::Block *currEntryBlock = cgf.curLexScope->getEntryBlock();
[[maybe_unused]] mlir::Value tmpResumeRValAddr;

// No need to explicitly wrap this into a scope since the AST already uses a
// ExprWithCleanups, which will wrap this into a cir.scope anyways.
rval = emitSuspendExpression(cgf, *cgf.curCoro.data, e, kind, aggSlot,
ignoreResult, currEntryBlock, tmpResumeRValAddr,
/*forLValue*/ false)
.rv;

if (ignoreResult || rval.isIgnored())
return rval;

if (rval.isScalar()) {
rval = RValue::get(cir::LoadOp::create(cgf.getBuilder(), scopeLoc,
rval.getValue().getType(),
tmpResumeRValAddr));
} else if (rval.isAggregate()) {
// This is probably already handled via AggSlot, remove this assertion
// once we have a testcase and prove all pieces work.
cgf.cgm.errorNYI("emitSuspendExpr Aggregate");
} else { // complex
cgf.cgm.errorNYI("emitSuspendExpr Complex");
}
return rval;
}

RValue CIRGenFunction::emitCoawaitExpr(const CoawaitExpr &e,
AggValueSlot aggSlot,
bool ignoreResult) {
return emitSuspendExpr(*this, e, curCoro.data->currentAwaitKind, aggSlot,
ignoreResult);
}
4 changes: 4 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
return cgf.emitLoadOfLValue(lv, e->getExprLoc()).getValue();
}

mlir::Value VisitCoawaitExpr(CoawaitExpr *s) {
return cgf.emitCoawaitExpr(*s).getValue();
}

mlir::Value emitLoadOfLValue(LValue lv, SourceLocation loc) {
return cgf.emitLoadOfLValue(lv, loc).getValue();
}
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CIR/CodeGen/CIRGenFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,9 @@ class CIRGenFunction : public CIRGenTypeCache {
void emitForwardingCallToLambda(const CXXMethodDecl *lambdaCallOperator,
CallArgList &callArgs);

RValue emitCoawaitExpr(const CoawaitExpr &e,
AggValueSlot aggSlot = AggValueSlot::ignored(),
bool ignoreResult = false);
/// Emit the computation of the specified expression of complex type,
/// returning the result.
mlir::Value emitComplexExpr(const Expr *e);
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/CodeGen/CIRGenValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ class RValue {
bool isScalar() const { return flavor == Scalar; }
bool isComplex() const { return flavor == Complex; }
bool isAggregate() const { return flavor == Aggregate; }
bool isIgnored() const { return isScalar() && !getValue(); }

bool isVolatileQualified() const { return isVolatile; }

Expand Down
76 changes: 73 additions & 3 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,10 @@ void cir::ConditionOp::getSuccessorRegions(
regions.emplace_back(getOperation(), loopOp->getResults());
}

assert(!cir::MissingFeatures::awaitOp());
// Parent is an await: condition may branch to resume or suspend regions.
auto await = cast<AwaitOp>(getOperation()->getParentOp());
regions.emplace_back(&await.getResume(), await.getResume().getArguments());
regions.emplace_back(&await.getSuspend(), await.getSuspend().getArguments());
}

MutableOperandRange
Expand All @@ -299,8 +302,7 @@ cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
}

LogicalResult cir::ConditionOp::verify() {
assert(!cir::MissingFeatures::awaitOp());
if (!isa<LoopOpInterface>(getOperation()->getParentOp()))
if (!isa<LoopOpInterface, AwaitOp>(getOperation()->getParentOp()))
return emitOpError("condition must be within a conditional region");
return success();
}
Expand Down Expand Up @@ -1910,6 +1912,19 @@ void cir::FuncOp::print(OpAsmPrinter &p) {

mlir::LogicalResult cir::FuncOp::verify() {

if (!isDeclaration() && getCoroutine()) {
bool foundAwait = false;
this->walk([&](Operation *op) {
if (auto await = dyn_cast<AwaitOp>(op)) {
foundAwait = true;
return;
}
});
if (!foundAwait)
return emitOpError()
<< "coroutine body must use at least one cir.await op";
}

llvm::SmallSet<llvm::StringRef, 16> labels;
llvm::SmallSet<llvm::StringRef, 16> gotos;
llvm::SmallSet<llvm::StringRef, 16> blockAddresses;
Expand Down Expand Up @@ -2149,6 +2164,61 @@ OpFoldResult cir::UnaryOp::fold(FoldAdaptor adaptor) {

return {};
}
//===----------------------------------------------------------------------===//
// AwaitOp
//===----------------------------------------------------------------------===//

void cir::AwaitOp::build(OpBuilder &builder, OperationState &result,
cir::AwaitKind kind, BuilderCallbackRef readyBuilder,
BuilderCallbackRef suspendBuilder,
BuilderCallbackRef resumeBuilder) {
result.addAttribute(getKindAttrName(result.name),
cir::AwaitKindAttr::get(builder.getContext(), kind));
{
OpBuilder::InsertionGuard guard(builder);
Region *readyRegion = result.addRegion();
builder.createBlock(readyRegion);
readyBuilder(builder, result.location);
}

{
OpBuilder::InsertionGuard guard(builder);
Region *suspendRegion = result.addRegion();
builder.createBlock(suspendRegion);
suspendBuilder(builder, result.location);
}

{
OpBuilder::InsertionGuard guard(builder);
Region *resumeRegion = result.addRegion();
builder.createBlock(resumeRegion);
resumeBuilder(builder, result.location);
}
}

void cir::AwaitOp::getSuccessorRegions(
mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
// If any index all the underlying regions branch back to the parent
// operation.
if (!point.isParent()) {
regions.push_back(
RegionSuccessor(getOperation(), getOperation()->getResults()));
return;
}

// TODO: retrieve information from the promise and only push the
// necessary ones. Example: `std::suspend_never` on initial or final
// await's might allow suspend region to be skipped.
regions.push_back(RegionSuccessor(&this->getReady()));
regions.push_back(RegionSuccessor(&this->getSuspend()));
regions.push_back(RegionSuccessor(&this->getResume()));
}

LogicalResult cir::AwaitOp::verify() {
if (!isa<ConditionOp>(this->getReady().back().getTerminator()))
return emitOpError("ready region must end with cir.condition");
return success();
}

//===----------------------------------------------------------------------===//
// CopyOp Definitions
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3841,6 +3841,12 @@ mlir::LogicalResult CIRToLLVMBlockAddressOpLowering::matchAndRewrite(
return mlir::failure();
}

mlir::LogicalResult CIRToLLVMAwaitOpLowering::matchAndRewrite(
cir::AwaitOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
return mlir::failure();
}

std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
return std::make_unique<ConvertCIRToLLVMPass>();
}
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CodeGen/CGValue.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class RValue {
bool isScalar() const { return Flavor == Scalar; }
bool isComplex() const { return Flavor == Complex; }
bool isAggregate() const { return Flavor == Aggregate; }
bool isIgnored() const { return isScalar() && !getScalarVal(); }

bool isVolatileQualified() const { return IsVolatile; }

Expand Down
Loading