Skip to content

[Clang][Coroutines] Introducing the [[clang::coro_inplace_task]] attribute #94693

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions clang/include/clang/AST/ExprCXX.h
Original file line number Diff line number Diff line change
Expand Up @@ -5082,18 +5082,19 @@ class CoroutineSuspendExpr : public Expr {
enum SubExpr { Operand, Common, Ready, Suspend, Resume, Count };

Stmt *SubExprs[SubExpr::Count];
OpaqueValueExpr *OpaqueValue = nullptr;
OpaqueValueExpr *CommonExprOpaqueValue = nullptr;
OpaqueValueExpr *InplaceCallOpaqueValue = nullptr;

public:
// These types correspond to the three C++ 'await_suspend' return variants
enum class SuspendReturnType { SuspendVoid, SuspendBool, SuspendHandle };

CoroutineSuspendExpr(StmtClass SC, SourceLocation KeywordLoc, Expr *Operand,
Expr *Common, Expr *Ready, Expr *Suspend, Expr *Resume,
OpaqueValueExpr *OpaqueValue)
OpaqueValueExpr *CommonExprOpaqueValue)
: Expr(SC, Resume->getType(), Resume->getValueKind(),
Resume->getObjectKind()),
KeywordLoc(KeywordLoc), OpaqueValue(OpaqueValue) {
KeywordLoc(KeywordLoc), CommonExprOpaqueValue(CommonExprOpaqueValue) {
SubExprs[SubExpr::Operand] = Operand;
SubExprs[SubExpr::Common] = Common;
SubExprs[SubExpr::Ready] = Ready;
Expand Down Expand Up @@ -5128,7 +5129,16 @@ class CoroutineSuspendExpr : public Expr {
}

/// getOpaqueValue - Return the opaque value placeholder.
OpaqueValueExpr *getOpaqueValue() const { return OpaqueValue; }
OpaqueValueExpr *getCommonExprOpaqueValue() const {
return CommonExprOpaqueValue;
}

OpaqueValueExpr *getInplaceCallOpaqueValue() const {
return InplaceCallOpaqueValue;
}
void setInplaceCallOpaqueValue(OpaqueValueExpr *E) {
InplaceCallOpaqueValue = E;
}

Expr *getReadyExpr() const {
return static_cast<Expr*>(SubExprs[SubExpr::Ready]);
Expand Down Expand Up @@ -5194,9 +5204,9 @@ class CoawaitExpr : public CoroutineSuspendExpr {
public:
CoawaitExpr(SourceLocation CoawaitLoc, Expr *Operand, Expr *Common,
Expr *Ready, Expr *Suspend, Expr *Resume,
OpaqueValueExpr *OpaqueValue, bool IsImplicit = false)
OpaqueValueExpr *CommonExprOpaqueValue, bool IsImplicit = false)
: CoroutineSuspendExpr(CoawaitExprClass, CoawaitLoc, Operand, Common,
Ready, Suspend, Resume, OpaqueValue) {
Ready, Suspend, Resume, CommonExprOpaqueValue) {
CoawaitBits.IsImplicit = IsImplicit;
}

Expand Down Expand Up @@ -5275,9 +5285,9 @@ class CoyieldExpr : public CoroutineSuspendExpr {
public:
CoyieldExpr(SourceLocation CoyieldLoc, Expr *Operand, Expr *Common,
Expr *Ready, Expr *Suspend, Expr *Resume,
OpaqueValueExpr *OpaqueValue)
OpaqueValueExpr *CommonExprOpaqueValue)
: CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Operand, Common,
Ready, Suspend, Resume, OpaqueValue) {}
Ready, Suspend, Resume, CommonExprOpaqueValue) {}
CoyieldExpr(SourceLocation CoyieldLoc, QualType Ty, Expr *Operand,
Expr *Common)
: CoroutineSuspendExpr(CoyieldExprClass, CoyieldLoc, Ty, Operand,
Expand Down
8 changes: 8 additions & 0 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,14 @@ def CoroDisableLifetimeBound : InheritableAttr {
let SimpleHandler = 1;
}

def CoroInplaceTask : InheritableAttr {
let Spellings = [Clang<"coro_inplace_task">];
let Subjects = SubjectList<[CXXRecord]>;
let LangOpts = [CPlusPlus];
let Documentation = [CoroInplaceTaskDoc];
let SimpleHandler = 1;
}

// OSObject-based attributes.
def OSConsumed : InheritableParamAttr {
let Spellings = [Clang<"os_consumed">];
Expand Down
19 changes: 19 additions & 0 deletions clang/include/clang/Basic/AttrDocs.td
Original file line number Diff line number Diff line change
Expand Up @@ -8108,6 +8108,25 @@ but do not pass them to the underlying coroutine or pass them by value.
}];
}

def CoroInplaceTaskDoc : Documentation {
let Category = DocCatDecl;
let Content = [{
The ``[[clang::coro_inplace_task]]`` is a class attribute which can be applied
to a coroutine return type.

When a coroutine function that returns such a type calls another coroutine function,
the compiler performs heap allocation elision when the following conditions are all met:
- callee coroutine function returns a type that is annotated with ``[[clang::coro_inplace_task]]``.
- The callee coroutine function is inlined.
- In caller coroutine, the return value of the callee is a prvalue or an xvalue, and
- The temporary expression containing the callee coroutine object is immediately co_awaited.

The behavior is undefined if any of the following condition was met:
- the caller coroutine is destroyed earlier than the callee coroutine.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to carefully analyze any attributes that add new forms of undefined behavior. How do we expect the user to avoid this case? Is there some way we can make the behavior here deterministic? If we can't make it deterministic, is there some sanitizer that would catch this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the scrutiny here. In coroutine's case, developers don't author Task types themselves usually. The use case of this attribute is mostly within library/framework code. The attribute should only be used when such a library needs to communicate such a guarantee to the compiler.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make sure we're clear about exactly which case we're talking about, can you write an example that triggers undefined behavior?

I'm not sure I see the connection between writing a task type and ensuring coroutines are destroyed in the right order... are you saying that a well-behaved Task will ensure destruction always happens in the right order, regardless of how it's used?

I'd still like an answer to my question about sanitizers.

Copy link
Member Author

@yuxuanchen1997 yuxuanchen1997 Jun 14, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To make sure we're clear about exactly which case we're talking about, can you write an example that triggers undefined behavior?

Sure. Though the UB needs to be triggered from a place that's either:

  1. you have access to the handle to the callee after the caller has been destroyed.
  2. you destroy the caller of a currently running callee (potentially from another thread of execution).

An example would be

std::coroutine_handle<> await_suspend(std::coroutine_handle<> caller_handle) {
  caller_handle.destroy();
  return this->handle;
}

A task type whose associated awaiter implements its await_suspend like this should not be attributed structured concurrency. Same goes for other customization points where you get a hold of handles from both caller and the callee. Same goes for APIs in Task and Awaiter types that help other code extract both handles.

are you saying that a well-behaved Task will ensure destruction always happens in the right order, regardless of how it's used?

Yes. This is the assumption. The Task/Promise and even Awaiter types holding this attribute should not save/allow extraction of callee handle for the purpose of resumption. When such a way to break the structuredness is provided, the Task type should not be attributed as coro_structured_concurrency. This patch has no intention to eradicate the use of nonstructured concurrency. There are legitimate uses of them. It's just close to impossible to perform HALO.

I'd still like an answer to my question about sanitizers.

Missed this one in my prior response. The UB triggered from violation of the contract is effectively a use-after-free. Not a sanitizers expert, but ASan sounds like able to catch this?


}];
}

def CountedByDocs : Documentation {
let Category = DocCatField;
let Content = [{
Expand Down
5 changes: 3 additions & 2 deletions clang/lib/CodeGen/CGBlocks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,8 @@ llvm::Type *CodeGenModule::getGenericBlockLiteralType() {
}

RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E,
ReturnValueSlot ReturnValue) {
ReturnValueSlot ReturnValue,
llvm::CallBase **CallOrInvoke) {
const auto *BPT = E->getCallee()->getType()->castAs<BlockPointerType>();
llvm::Value *BlockPtr = EmitScalarExpr(E->getCallee());
llvm::Type *GenBlockTy = CGM.getGenericBlockLiteralType();
Expand Down Expand Up @@ -1220,7 +1221,7 @@ RValue CodeGenFunction::EmitBlockCallExpr(const CallExpr *E,
CGCallee Callee(CGCalleeInfo(), Func);

// And call the block.
return EmitCall(FnInfo, Callee, ReturnValue, Args);
return EmitCall(FnInfo, Callee, ReturnValue, Args, CallOrInvoke);
}

Address CodeGenFunction::GetAddrOfBlockDecl(const VarDecl *variable) {
Expand Down
5 changes: 3 additions & 2 deletions clang/lib/CodeGen/CGCUDARuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ CGCUDARuntime::~CGCUDARuntime() {}

RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
const CUDAKernelCallExpr *E,
ReturnValueSlot ReturnValue) {
ReturnValueSlot ReturnValue,
llvm::CallBase **CallOrInvoke) {
llvm::BasicBlock *ConfigOKBlock = CGF.createBasicBlock("kcall.configok");
llvm::BasicBlock *ContBlock = CGF.createBasicBlock("kcall.end");

Expand All @@ -35,7 +36,7 @@ RValue CGCUDARuntime::EmitCUDAKernelCallExpr(CodeGenFunction &CGF,

eval.begin(CGF);
CGF.EmitBlock(ConfigOKBlock);
CGF.EmitSimpleCallExpr(E, ReturnValue);
CGF.EmitSimpleCallExpr(E, ReturnValue, CallOrInvoke);
CGF.EmitBranch(ContBlock);

CGF.EmitBlock(ContBlock);
Expand Down
8 changes: 5 additions & 3 deletions clang/lib/CodeGen/CGCUDARuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/IR/GlobalValue.h"

namespace llvm {
class CallBase;
class Function;
class GlobalVariable;
}
Expand Down Expand Up @@ -82,9 +83,10 @@ class CGCUDARuntime {
CGCUDARuntime(CodeGenModule &CGM) : CGM(CGM) {}
virtual ~CGCUDARuntime();

virtual RValue EmitCUDAKernelCallExpr(CodeGenFunction &CGF,
const CUDAKernelCallExpr *E,
ReturnValueSlot ReturnValue);
virtual RValue
EmitCUDAKernelCallExpr(CodeGenFunction &CGF, const CUDAKernelCallExpr *E,
ReturnValueSlot ReturnValue,
llvm::CallBase **CallOrInvoke = nullptr);

/// Emits a kernel launch stub.
virtual void emitDeviceStub(CodeGenFunction &CGF, FunctionArgList &Args) = 0;
Expand Down
10 changes: 5 additions & 5 deletions clang/lib/CodeGen/CGCXXABI.h
Original file line number Diff line number Diff line change
Expand Up @@ -485,11 +485,11 @@ class CGCXXABI {
llvm::PointerUnion<const CXXDeleteExpr *, const CXXMemberCallExpr *>;

/// Emit the ABI-specific virtual destructor call.
virtual llvm::Value *EmitVirtualDestructorCall(CodeGenFunction &CGF,
const CXXDestructorDecl *Dtor,
CXXDtorType DtorType,
Address This,
DeleteOrMemberCallExpr E) = 0;
virtual llvm::Value *
EmitVirtualDestructorCall(CodeGenFunction &CGF, const CXXDestructorDecl *Dtor,
CXXDtorType DtorType, Address This,
DeleteOrMemberCallExpr E,
llvm::CallBase **CallOrInvoke) = 0;

virtual void adjustCallArgsForDestructorThunk(CodeGenFunction &CGF,
GlobalDecl GD,
Expand Down
16 changes: 6 additions & 10 deletions clang/lib/CodeGen/CGClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2191,15 +2191,11 @@ static bool canEmitDelegateCallArgs(CodeGenFunction &CGF,
return true;
}

void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D,
CXXCtorType Type,
bool ForVirtualBase,
bool Delegating,
Address This,
CallArgList &Args,
AggValueSlot::Overlap_t Overlap,
SourceLocation Loc,
bool NewPointerIsChecked) {
void CodeGenFunction::EmitCXXConstructorCall(
const CXXConstructorDecl *D, CXXCtorType Type, bool ForVirtualBase,
bool Delegating, Address This, CallArgList &Args,
AggValueSlot::Overlap_t Overlap, SourceLocation Loc,
bool NewPointerIsChecked, llvm::CallBase **CallOrInvoke) {
const CXXRecordDecl *ClassDecl = D->getParent();

if (!NewPointerIsChecked)
Expand Down Expand Up @@ -2247,7 +2243,7 @@ void CodeGenFunction::EmitCXXConstructorCall(const CXXConstructorDecl *D,
const CGFunctionInfo &Info = CGM.getTypes().arrangeCXXConstructorCall(
Args, D, Type, ExtraArgs.Prefix, ExtraArgs.Suffix, PassPrototypeArgs);
CGCallee Callee = CGCallee::forDirect(CalleePtr, GlobalDecl(D, Type));
EmitCall(Info, Callee, ReturnValueSlot(), Args, nullptr, false, Loc);
EmitCall(Info, Callee, ReturnValueSlot(), Args, CallOrInvoke, false, Loc);

// Generate vtable assumptions if we're constructing a complete object
// with a vtable. We don't do this for base subobjects for two reasons:
Expand Down
30 changes: 21 additions & 9 deletions clang/lib/CodeGen/CGCoroutine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@

#include "CGCleanup.h"
#include "CodeGenFunction.h"
#include "llvm/ADT/ScopeExit.h"
#include "clang/AST/ExprCXX.h"
#include "clang/AST/StmtCXX.h"
#include "clang/AST/StmtVisitor.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/IR/Intrinsics.h"

using namespace clang;
using namespace CodeGen;
Expand Down Expand Up @@ -223,12 +225,22 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
CoroutineSuspendExpr const &S,
AwaitKind Kind, AggValueSlot aggSlot,
bool ignoreResult, bool forLValue) {
auto *E = S.getCommonExpr();
auto &Builder = CGF.Builder;

auto CommonBinder =
CodeGenFunction::OpaqueValueMappingData::bind(CGF, S.getOpaqueValue(), E);
auto UnbindCommonOnExit =
llvm::make_scope_exit([&] { CommonBinder.unbind(CGF); });
// If S.getInplaceCallOpaqueValue() is null, we don't have a nested opaque
// value for common expression.
std::optional<CodeGenFunction::OpaqueValueMapping> OperandMapping;
if (auto *CallOV = S.getInplaceCallOpaqueValue()) {
auto *CE = cast<CallExpr>(CallOV->getSourceExpr());
llvm::CallBase *CallOrInvoke = nullptr;
LValue CallResult = CGF.EmitCallExprLValue(CE, &CallOrInvoke);
if (CallOrInvoke)
CallOrInvoke->addAnnotationMetadata("coro_must_elide");

OperandMapping.emplace(CGF, CallOV, CallResult);
}
CodeGenFunction::OpaqueValueMapping BindCommon(CGF,
S.getCommonExprOpaqueValue());

auto Prefix = buildSuspendPrefixStr(Coro, Kind);
BasicBlock *ReadyBlock = CGF.createBasicBlock(Prefix + Twine(".ready"));
Expand All @@ -241,7 +253,6 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co
// Otherwise, emit suspend logic.
CGF.EmitBlock(SuspendBlock);

auto &Builder = CGF.Builder;
llvm::Function *CoroSave = CGF.CGM.getIntrinsic(llvm::Intrinsic::coro_save);
auto *NullPtr = llvm::ConstantPointerNull::get(CGF.CGM.Int8PtrTy);
auto *SaveCall = Builder.CreateCall(CoroSave, {NullPtr});
Expand All @@ -256,7 +267,8 @@ static LValueOrRValue emitSuspendExpression(CodeGenFunction &CGF, CGCoroData &Co

SmallVector<llvm::Value *, 3> SuspendIntrinsicCallArgs;
SuspendIntrinsicCallArgs.push_back(
CGF.getOrCreateOpaqueLValueMapping(S.getOpaqueValue()).getPointer(CGF));
CGF.getOrCreateOpaqueLValueMapping(S.getCommonExprOpaqueValue())
.getPointer(CGF));

SuspendIntrinsicCallArgs.push_back(CGF.CurCoro.Data->CoroBegin);
SuspendIntrinsicCallArgs.push_back(SuspendWrapper);
Expand Down Expand Up @@ -455,7 +467,7 @@ CodeGenFunction::generateAwaitSuspendWrapper(Twine const &CoroName,
Builder.CreateLoad(GetAddrOfLocalVar(&FrameDecl));

auto AwaiterBinder = CodeGenFunction::OpaqueValueMappingData::bind(
*this, S.getOpaqueValue(), AwaiterLValue);
*this, S.getCommonExprOpaqueValue(), AwaiterLValue);

auto *SuspendRet = EmitScalarExpr(S.getSuspendExpr());

Expand Down
41 changes: 25 additions & 16 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5437,24 +5437,25 @@ RValue CodeGenFunction::EmitRValueForField(LValue LV,
//===--------------------------------------------------------------------===//

RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
ReturnValueSlot ReturnValue) {
ReturnValueSlot ReturnValue,
llvm::CallBase **CallOrInvoke) {
// Builtins never have block type.
if (E->getCallee()->getType()->isBlockPointerType())
return EmitBlockCallExpr(E, ReturnValue);
return EmitBlockCallExpr(E, ReturnValue, CallOrInvoke);

if (const auto *CE = dyn_cast<CXXMemberCallExpr>(E))
return EmitCXXMemberCallExpr(CE, ReturnValue);
return EmitCXXMemberCallExpr(CE, ReturnValue, CallOrInvoke);

if (const auto *CE = dyn_cast<CUDAKernelCallExpr>(E))
return EmitCUDAKernelCallExpr(CE, ReturnValue);
return EmitCUDAKernelCallExpr(CE, ReturnValue, CallOrInvoke);

// A CXXOperatorCallExpr is created even for explicit object methods, but
// these should be treated like static function call.
if (const auto *CE = dyn_cast<CXXOperatorCallExpr>(E))
if (const auto *MD =
dyn_cast_if_present<CXXMethodDecl>(CE->getCalleeDecl());
MD && MD->isImplicitObjectMemberFunction())
return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue);
return EmitCXXOperatorMemberCallExpr(CE, MD, ReturnValue, CallOrInvoke);

CGCallee callee = EmitCallee(E->getCallee());

Expand All @@ -5467,14 +5468,17 @@ RValue CodeGenFunction::EmitCallExpr(const CallExpr *E,
return EmitCXXPseudoDestructorExpr(callee.getPseudoDestructorExpr());
}

return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue);
return EmitCall(E->getCallee()->getType(), callee, E, ReturnValue,
/*Chain=*/nullptr, CallOrInvoke);
}

/// Emit a CallExpr without considering whether it might be a subclass.
RValue CodeGenFunction::EmitSimpleCallExpr(const CallExpr *E,
ReturnValueSlot ReturnValue) {
ReturnValueSlot ReturnValue,
llvm::CallBase **CallOrInvoke) {
CGCallee Callee = EmitCallee(E->getCallee());
return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue);
return EmitCall(E->getCallee()->getType(), Callee, E, ReturnValue,
/*Chain=*/nullptr, CallOrInvoke);
}

// Detect the unusual situation where an inline version is shadowed by a
Expand Down Expand Up @@ -5678,8 +5682,9 @@ LValue CodeGenFunction::EmitBinaryOperatorLValue(const BinaryOperator *E) {
llvm_unreachable("bad evaluation kind");
}

LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E) {
RValue RV = EmitCallExpr(E);
LValue CodeGenFunction::EmitCallExprLValue(const CallExpr *E,
llvm::CallBase **CallOrInvoke) {
RValue RV = EmitCallExpr(E, ReturnValueSlot(), CallOrInvoke);

if (!RV.isScalar())
return MakeAddrLValue(RV.getAggregateAddress(), E->getType(),
Expand Down Expand Up @@ -5802,9 +5807,11 @@ LValue CodeGenFunction::EmitStmtExprLValue(const StmtExpr *E) {
AlignmentSource::Decl);
}

RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee,
const CallExpr *E, ReturnValueSlot ReturnValue,
llvm::Value *Chain) {
RValue CodeGenFunction::EmitCall(QualType CalleeType,
const CGCallee &OrigCallee, const CallExpr *E,
ReturnValueSlot ReturnValue,
llvm::Value *Chain,
llvm::CallBase **CallOrInvoke) {
// Get the actual function type. The callee type will always be a pointer to
// function type or a block pointer type.
assert(CalleeType->isFunctionPointerType() &&
Expand Down Expand Up @@ -6015,8 +6022,8 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee
Address(Handle, Handle->getType(), CGM.getPointerAlign()));
Callee.setFunctionPointer(Stub);
}
llvm::CallBase *CallOrInvoke = nullptr;
RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &CallOrInvoke,
llvm::CallBase *LocalCallOrInvoke = nullptr;
RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke,
E == MustTailCall, E->getExprLoc());

// Generate function declaration DISuprogram in order to be used
Expand All @@ -6025,11 +6032,13 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, const CGCallee &OrigCallee
if (auto *CalleeDecl = dyn_cast_or_null<FunctionDecl>(TargetDecl)) {
FunctionArgList Args;
QualType ResTy = BuildFunctionArgList(CalleeDecl, Args);
DI->EmitFuncDeclForCallSite(CallOrInvoke,
DI->EmitFuncDeclForCallSite(LocalCallOrInvoke,
DI->getFunctionType(CalleeDecl, ResTy, Args),
CalleeDecl);
}
}
if (CallOrInvoke)
*CallOrInvoke = LocalCallOrInvoke;

return Call;
}
Expand Down
Loading
Loading