-
Notifications
You must be signed in to change notification settings - Fork 15.3k
Attribute support [[clang::musttail]] in ExprConstant.cpp (work in progress)
#138477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…ows guaranteed tail recursion.
|
@llvm/pr-subscribers-clang

Author: Hana Dusíková (hanickadot)

Changes

This change makes `[[clang::musttail]]` work for constant evaluation.

This PR is work in progress.

Full diff: https://github.com/llvm/llvm-project/pull/138477.diff

1 File Affected:
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index b79d8c197fe7d..9ef6b983d196a 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -735,6 +735,13 @@ namespace {
ScopeKind Scope)
: Value(Val, Scope), Base(Base), T(T) {}
+ /// Move constructor: takes over \p Other's cleanup and clears \p Other's
+ /// value pointer, so the moved-from entry no longer refers to the object.
+ Cleanup(Cleanup &&Other) noexcept
+ : Value{Other.Value}, Base{Other.Base}, T{Other.T} {
+ Other.Value = {};
+ }
+
+ /// Move assignment with the same moved-from state as the move constructor.
+ /// A defaulted move assignment would leave \p Other still pointing at the
+ /// object, which is inconsistent with the move constructor above.
+ Cleanup &operator=(Cleanup &&Other) noexcept {
+ if (this == &Other)
+ return *this;
+ Value = Other.Value;
+ Base = Other.Base;
+ T = Other.T;
+ Other.Value = {};
+ return *this;
+ }
+
/// Determine whether this cleanup should be performed at the end of the
/// given kind of scope.
bool isDestroyedAtEndOf(ScopeKind K) const {
@@ -1006,6 +1013,24 @@ namespace {
EM_IgnoreSideEffects,
} EvalMode;
+ /// The return statement marked with [[clang::musttail]] whose call may be
+ /// deferred, or null. Set by EnableTailRecursion; cleared by
+ /// DisableTailRecursion, or reset to null when the next call expression is
+ /// evaluated.
+ const ReturnStmt *TailRecursionReturnStmt = nullptr;
+
+ /// State of a call deferred from a [[clang::musttail]] return. A call is
+ /// pending while \c E is non-null. The caller's call-evaluation loop
+ /// performs it after the current function call has finished.
+ struct DeferRecursionFunctionCall {
+ /// The deferred call expression; non-null while a call is pending.
+ const CallExpr *E{nullptr};
+ /// Definition of the callee.
+ const FunctionDecl *Definition{nullptr};
+ /// Whether the call has a 'this' object (i.e. ThisVal is meaningful).
+ bool HasThis{false};
+ APValue ThisVal{}; // Kept as an APValue; the author notes LValue can't be used here.
+ /// Argument expressions of the call.
+ llvm::ArrayRef<const clang::Expr *> Args{};
+ /// Call slot; its parameters were created in the caller's frame.
+ CallRef Call{};
+ /// Body of the callee.
+ Stmt *Body{nullptr};
+ /// Covariant return adjustment path for the deferred call.
+ SmallVector<QualType, 4> CovariantAdjustmentPath{};
+ /// Cleanups of the already-evaluated arguments. They are moved out of the
+ /// scope that prepared the call so they are not destroyed early.
+ SmallVector<Cleanup, 16> ArgumentsStored{};
+ };
+
+ /// The currently pending deferred call, if any (at most one at a time).
+ DeferRecursionFunctionCall DeferFunctionCall{};
+
/// Are we checking whether the expression is a potential constant
/// expression?
bool checkingPotentialConstantExpression() const override {
@@ -1124,6 +1149,21 @@ namespace {
return Result;
}
+ /// Mark \p ret as the return statement whose call may be deferred as a
+ /// tail call (used when the statement carries [[clang::musttail]]).
+ void EnableTailRecursion(const ReturnStmt *ret) {
+ TailRecursionReturnStmt = ret;
+ }
+
+ /// Forget any pending tail-call marker.
+ void DisableTailRecursion() { TailRecursionReturnStmt = nullptr; }
+
+ /// True when a tail call has been recorded in DeferFunctionCall and is
+ /// waiting to be performed.
+ bool TailRecursionReady() const { return DeferFunctionCall.E != nullptr; }
+
+ /// If \p ret is the marked tail-call return, consume the marker and return
+ /// true; otherwise return false and leave the marker untouched.
+ /// NOTE(review): not called anywhere in this diff — remove or use it.
+ bool IsTailRecursion(const ReturnStmt *ret) {
+ if (TailRecursionReturnStmt != ret)
+ return false;
+ TailRecursionReturnStmt = nullptr;
+ return true;
+ }
+
/// Get the allocated storage for the given parameter of the given call.
APValue *getParamSlot(CallRef Call, const ParmVarDecl *PVD) {
CallStackFrame *Frame = getCallFrameAndDepth(Call.CallIndex).first;
@@ -1439,6 +1479,12 @@ namespace {
// instances of this class.
Info.CurrentCall->popTempVersion();
}
+
+ friend void transferFromCallScope(ScopeRAII &,
+ llvm::SmallVectorImpl<Cleanup> &);
+ friend bool transferIntoCallScope(ScopeRAII &,
+ llvm::SmallVectorImpl<Cleanup> &);
+
private:
static bool cleanup(EvalInfo &Info, bool RunDestructors,
unsigned OldStackSize) {
@@ -1457,6 +1503,10 @@ namespace {
}
}
+ compact(Info, OldStackSize);
+ return Success;
+ }
+ static void compact(EvalInfo &Info, unsigned OldStackSize) {
// Compact any retained cleanups.
auto NewEnd = Info.CleanupStack.begin() + OldStackSize;
if (Kind != ScopeKind::Block)
@@ -1465,12 +1515,47 @@ namespace {
return C.isDestroyedAtEndOf(Kind);
});
Info.CleanupStack.erase(NewEnd, Info.CleanupStack.end());
- return Success;
}
};
typedef ScopeRAII<ScopeKind::Block> BlockScopeRAII;
typedef ScopeRAII<ScopeKind::FullExpression> FullExpressionRAII;
typedef ScopeRAII<ScopeKind::Call> CallScopeRAII;
+
+ /// Move the cleanups registered in \p Scope (the lifetimes of the arguments
+ /// prepared for a deferred tail call) into \p Backup, so they are not
+ /// destroyed when \p Scope ends. Any previous contents of \p Backup are
+ /// discarded.
+ static void transferFromCallScope(CallScopeRAII &Scope,
+ llvm::SmallVectorImpl<Cleanup> &Backup) {
+ // The cleanups belonging to this scope live above OldStackSize.
+ auto CurrentVariables = MutableArrayRef<Cleanup>(Scope.Info.CleanupStack)
+ .slice(Scope.OldStackSize);
+
+ // Transfer the cleanup information of the tail call out of the current
+ // scope. Otherwise these variables would be destroyed by the current scope,
+ // which only prepares the tail call but does not perform it.
+ Backup.clear();
+ Backup.reserve(CurrentVariables.size());
+ for (Cleanup &Lifetime : CurrentVariables)
+ Backup.push_back(std::move(Lifetime));
+
+ // Remove the (now moved-from) entries from this scope's cleanup stack.
+ Scope.compact(Scope.Info, Scope.OldStackSize);
+ // NOTE(review): compact() appears to keep entries that are not destroyed at
+ // the end of this scope kind; truncate() drops them unconditionally —
+ // confirm this is intended.
+ Scope.Info.CleanupStack.truncate(Scope.OldStackSize);
+ assert(Scope.Info.CleanupStack.size() == Scope.OldStackSize);
+ }
+
+ /// Run the pending cleanups of \p Scope above its base (with destructors
+ /// enabled). Then re-register the argument cleanups saved in \p Backup on
+ /// the cleanup stack, so they belong to \p Scope from now on.
+ /// Returns false if running the pending cleanups failed; \p Backup is left
+ /// untouched in that case.
+ static bool transferIntoCallScope(CallScopeRAII &Scope,
+ llvm::SmallVectorImpl<Cleanup> &Backup) {
+ if (!Scope.cleanup(Scope.Info, /*RunDestructors=*/true, Scope.OldStackSize))
+ return false;
+
+ for (auto &Lifetime : Backup) {
+ Scope.Info.CleanupStack.push_back(std::move(Lifetime));
+ }
+
+ Backup.clear();
+ return true;
+ }
}
bool SubobjectDesignator::checkSubobject(EvalInfo &Info, const Expr *E,
@@ -5614,10 +5699,14 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, EvalInfo &Info,
// We know we returned, but we don't know what the value is.
return ESR_Failed;
}
- if (RetExpr &&
- !(Result.Slot
- ? EvaluateInPlace(Result.Value, Info, *Result.Slot, RetExpr)
- : Evaluate(Result.Value, Info, RetExpr)))
+
+ if (!RetExpr || !isa<CallExpr>(RetExpr)) {
+ Info.DisableTailRecursion();
+ }
+
+ if (RetExpr && !(Result.Slot ? EvaluateInPlace(Result.Value, Info,
+ *Result.Slot, RetExpr)
+ : Evaluate(Result.Value, Info, RetExpr)))
return ESR_Failed;
return Scope.destroy() ? ESR_Returned : ESR_Failed;
}
@@ -5869,32 +5958,37 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, EvalInfo &Info,
case Stmt::AttributedStmtClass: {
const auto *AS = cast<AttributedStmt>(S);
const auto *SS = AS->getSubStmt();
+ const auto *RS = dyn_cast<ReturnStmt>(SS);
MSConstexprContextRAII ConstexprContext(
- *Info.CurrentCall, hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) &&
- isa<ReturnStmt>(SS));
+ *Info.CurrentCall,
+ hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) && RS != nullptr);
auto LO = Info.getASTContext().getLangOpts();
- if (LO.CXXAssumptions && !LO.MSVCCompat) {
- for (auto *Attr : AS->getAttrs()) {
- auto *AA = dyn_cast<CXXAssumeAttr>(Attr);
- if (!AA)
- continue;
-
- auto *Assumption = AA->getAssumption();
- if (Assumption->isValueDependent())
- return ESR_Failed;
+ for (auto *Attr : AS->getAttrs()) {
+ if (auto *AA = dyn_cast<CXXAssumeAttr>(Attr)) {
+ // This branch handles C++'s [[assume(<EXPR>)]]
+ if (LO.CXXAssumptions && !LO.MSVCCompat) {
+ auto *Assumption = AA->getAssumption();
+ if (Assumption->isValueDependent())
+ return ESR_Failed;
- if (Assumption->HasSideEffects(Info.getASTContext()))
- continue;
+ if (Assumption->HasSideEffects(Info.getASTContext()))
+ continue;
- bool Value;
- if (!EvaluateAsBooleanCondition(Assumption, Value, Info))
- return ESR_Failed;
- if (!Value) {
- Info.CCEDiag(Assumption->getExprLoc(),
- diag::note_constexpr_assumption_failed);
- return ESR_Failed;
+ bool Value;
+ if (!EvaluateAsBooleanCondition(Assumption, Value, Info))
+ return ESR_Failed;
+ if (!Value) {
+ Info.CCEDiag(Assumption->getExprLoc(),
+ diag::note_constexpr_assumption_failed);
+ return ESR_Failed;
+ }
}
+ } else if (isa<MustTailAttr>(Attr) && RS != nullptr) {
+ // This branch handles [[clang::mustttail]] enforcement on
+ // tail-recursion which is strict and already checked, otherwise it will
+ // fail to compile.
+ Info.EnableTailRecursion(RS);
}
}
@@ -6514,16 +6608,16 @@ static bool MaybeHandleUnionActiveMemberChange(EvalInfo &Info,
static bool EvaluateCallArg(const ParmVarDecl *PVD, const Expr *Arg,
CallRef Call, EvalInfo &Info,
- bool NonNull = false) {
+ CallStackFrame &CallerFrame, bool NonNull = false) {
LValue LV;
// Create the parameter slot and register its destruction. For a vararg
// argument, create a temporary.
// FIXME: For calling conventions that destroy parameters in the callee,
// should we consider performing destruction when the function returns
// instead?
- APValue &V = PVD ? Info.CurrentCall->createParam(Call, PVD, LV)
- : Info.CurrentCall->createTemporary(Arg, Arg->getType(),
- ScopeKind::Call, LV);
+ APValue &V = PVD ? CallerFrame.createParam(Call, PVD, LV)
+ : CallerFrame.createTemporary(Arg, Arg->getType(),
+ ScopeKind::Call, LV);
if (!EvaluateInPlace(V, Info, LV, Arg))
return false;
@@ -6539,8 +6633,8 @@ static bool EvaluateCallArg(const ParmVarDecl *PVD, const Expr *Arg,
/// Evaluate the arguments to a function call.
static bool EvaluateArgs(ArrayRef<const Expr *> Args, CallRef Call,
- EvalInfo &Info, const FunctionDecl *Callee,
- bool RightToLeft = false) {
+ EvalInfo &Info, CallStackFrame &CallerFrame,
+ const FunctionDecl *Callee, bool RightToLeft = false) {
bool Success = true;
llvm::SmallBitVector ForbiddenNullArgs;
if (Callee->hasAttr<NonNullAttr>()) {
@@ -6563,7 +6657,7 @@ static bool EvaluateArgs(ArrayRef<const Expr *> Args, CallRef Call,
const ParmVarDecl *PVD =
Idx < Callee->getNumParams() ? Callee->getParamDecl(Idx) : nullptr;
bool NonNull = !ForbiddenNullArgs.empty() && ForbiddenNullArgs[Idx];
- if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, NonNull)) {
+ if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, CallerFrame, NonNull)) {
// If we're checking for a potential constant expression, evaluate all
// initializers even if some of them fail.
if (!Info.noteFailure())
@@ -6650,6 +6744,44 @@ static bool HandleFunctionCall(SourceLocation CallLoc,
return ESR == ESR_Returned;
}
+/// Defer a [[clang::musttail]] call instead of evaluating it now.
+/// The callee definition, 'this' value, arguments, call slot, body and
+/// covariant-return adjustment path are recorded in Info.DeferFunctionCall.
+/// The argument cleanups are moved out of \p Scope so the arguments outlive
+/// the frame that prepared the call.
+/// NOTE: \p ThisVal (via moveInto) and \p CovariantAdjustmentPath are moved
+/// from; callers must not rely on their contents afterwards.
+static void HandleTailCallTransfer(
+ EvalInfo &Info, const CallExpr *E, const FunctionDecl *Definition,
+ const LValue *This, LValue &ThisVal,
+ llvm::ArrayRef<const clang::Expr *> Args, CallRef Call, Stmt *Body,
+ SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) {
+ auto &defer = Info.DeferFunctionCall;
+
+ // A non-null E marks the call as pending (see TailRecursionReady).
+ defer.E = E;
+ defer.Definition = Definition;
+ defer.HasThis = This != nullptr;
+ ThisVal.moveInto(defer.ThisVal);
+ // Args only refers to the call's argument expressions; presumably those AST
+ // nodes outlive the deferred call — TODO confirm.
+ defer.Args = Args;
+ defer.Call = Call;
+ defer.Body = Body;
+ defer.CovariantAdjustmentPath = std::move(CovariantAdjustmentPath);
+
+ transferFromCallScope(Scope, defer.ArgumentsStored);
+}
+
+/// Restore the call previously recorded in Info.DeferFunctionCall into the
+/// caller's locals (all passed by reference) and clear the pending marker.
+/// Then move the saved argument cleanups into \p Scope.
+/// Returns false if running the pending cleanups of \p Scope failed.
+/// NOTE(review): CovariantAdjustmentPath is overwritten with the deferred
+/// call's path, so any adjustment required by the original (outer) call
+/// appears to be lost — confirm against the caller's "TODO checkme".
+static bool HandleTailCallSetup(
+ EvalInfo &Info, const CallExpr *&E, const FunctionDecl *&Definition,
+ LValue *&This, LValue &ThisVal, llvm::ArrayRef<const clang::Expr *> &Args,
+ CallRef &Call, Stmt *&Body,
+ SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) {
+ auto &defer = Info.DeferFunctionCall;
+ assert(defer.E != nullptr);
+
+ // Taking E resets the pending marker, so TailRecursionReady() becomes false
+ // unless the call performed next defers another tail call.
+ E = std::exchange(defer.E, nullptr);
+ Definition = defer.Definition;
+ ThisVal.setFrom(Info.Ctx, defer.ThisVal);
+ This = defer.HasThis ? &ThisVal : nullptr;
+ Args = defer.Args;
+ Call = defer.Call;
+ Body = defer.Body;
+ CovariantAdjustmentPath = std::move(defer.CovariantAdjustmentPath);
+ return transferIntoCallScope(Scope, defer.ArgumentsStored);
+}
+
/// Evaluate a constructor call.
static bool HandleConstructorCall(const Expr *E, const LValue &This,
CallRef Call,
@@ -6871,7 +7003,7 @@ static bool HandleConstructorCall(const Expr *E, const LValue &This,
EvalInfo &Info, APValue &Result) {
CallScopeRAII CallScope(Info);
CallRef Call = Info.CurrentCall->createCall(Definition);
- if (!EvaluateArgs(Args, Call, Info, Definition))
+ if (!EvaluateArgs(Args, Call, Info, *Info.CurrentCall, Definition))
return false;
return HandleConstructorCall(E, This, Call, Definition, Info, Result) &&
@@ -8242,6 +8374,13 @@ class ExprEvaluatorBase
APValue Result;
if (!handleCallExpr(E, Result, nullptr))
return false;
+
+ // When our current call is defered as a tail recursion
+ // we can't change result (yet).
+ if (Info.DeferFunctionCall.E != nullptr) {
+ return true;
+ }
+
return DerivedSuccess(Result, E);
}
@@ -8257,6 +8396,11 @@ class ExprEvaluatorBase
auto Args = llvm::ArrayRef(E->getArgs(), E->getNumArgs());
bool HasQualifier = false;
+ // Check for tail recursion, before we start evaluating any internal
+ // expression which can steal tail on their own.
+ const bool TailRecursion =
+ std::exchange(Info.TailRecursionReturnStmt, nullptr) != nullptr;
+
CallRef Call;
// Extract function decl and 'this' pointer from the callee.
@@ -8317,12 +8461,15 @@ class ExprEvaluatorBase
auto *OCE = dyn_cast<CXXOperatorCallExpr>(E);
if (OCE && OCE->isAssignmentOp()) {
assert(Args.size() == 2 && "wrong number of arguments in assignment");
- Call = Info.CurrentCall->createCall(FD);
bool HasThis = false;
if (const auto *MD = dyn_cast<CXXMethodDecl>(FD))
HasThis = MD->isImplicitObjectMemberFunction();
- if (!EvaluateArgs(HasThis ? Args.slice(1) : Args, Call, Info, FD,
- /*RightToLeft=*/true))
+
+ CallStackFrame &CallOriginFrame =
+ *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall);
+ Call = CallOriginFrame.createCall(FD);
+ if (!EvaluateArgs(HasThis ? Args.slice(1) : Args, Call, Info,
+ CallOriginFrame, FD, /*RightToLeft = */ true))
return false;
}
@@ -8404,8 +8551,10 @@ class ExprEvaluatorBase
// Evaluate the arguments now if we've not already done so.
if (!Call) {
- Call = Info.CurrentCall->createCall(FD);
- if (!EvaluateArgs(Args, Call, Info, FD))
+ CallStackFrame &CallOriginFrame =
+ *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall);
+ Call = CallOriginFrame.createCall(FD);
+ if (!EvaluateArgs(Args, Call, Info, CallOriginFrame, FD))
return false;
}
@@ -8438,11 +8587,40 @@ class ExprEvaluatorBase
const FunctionDecl *Definition = nullptr;
Stmt *Body = FD->getBody(Definition);
- if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body) ||
- !HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
+ if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body)) {
+ return false;
+ }
+
+ // If we are doing tail recursion, we need to store everything needed for
+ // the function call. There is always max one tail recursion prepared during
+ // execution of a program.
+ if (TailRecursion) {
+ HandleTailCallTransfer(Info, E, Definition, This, ThisVal, Args, Call,
+ Body, CovariantAdjustmentPath, CallScope);
+ return true;
+ }
+
+ if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
Body, Info, Result, ResultSlot))
return false;
+ // If we do tail recursion, we don't have result yet.
+ assert(!Info.TailRecursionReady() || Result.isAbsent());
+
+ // A tail recursion can result in another tail recursion, so we need to loop
+ // here.
+ while (Info.TailRecursionReady()) {
+ if (!HandleTailCallSetup(Info, E, Definition, This, ThisVal, Args, Call,
+ Body, CovariantAdjustmentPath, CallScope))
+ return false;
+
+ if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
+ Body, Info, Result, ResultSlot))
+ return false;
+ }
+
+ // TODO checkme this is correct
+ // We got out of tail recursion, it was just a normal function.
if (!CovariantAdjustmentPath.empty() &&
!HandleCovariantReturnAdjustment(Info, E, Result,
CovariantAdjustmentPath))
@@ -17832,7 +18010,7 @@ bool Expr::EvaluateWithSubstitution(APValue &Value, ASTContext &Ctx,
break;
const ParmVarDecl *PVD = Callee->getParamDecl(Idx);
if ((*I)->isValueDependent() ||
- !EvaluateCallArg(PVD, *I, Call, Info) ||
+ !EvaluateCallArg(PVD, *I, Call, Info, *Info.CurrentCall) ||
Info.EvalStatus.HasSideEffects) {
// If evaluation fails, throw away the argument entirely.
if (APValue *Slot = Info.getParamSlot(Call, PVD))
|
// Benchmark: counts `remaining` down to zero through self-recursion and then
// yields 42. MUST_TAIL is supplied on the compiler command line, e.g.
// -DMUST_TAIL="[[clang::musttail]]" to request a guaranteed tail call, or
// -DMUST_TAIL= to compile a plain recursive call.
constexpr int deep_test(int remaining) {
  if (remaining != 0) {
    MUST_TAIL return deep_test(remaining - 1);
  }
  return 42;
}
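constexpr int result = deep_test(200'000);

With a debug build of clang:

    $ clang++ -std=c++26 -c deep.cpp -fconstexpr-depth=1 -DMUST_TAIL="[[clang::musttail]]"
    real 0m7.227s
    user 0m7.199s
    sys  0m0.019s

    $ clang++ -std=c++26 -c deep.cpp -fconstexpr-depth=200000 -DMUST_TAIL=
    Segmentation fault: 11

With `[[clang::musttail]]`, evaluation succeeds even with `-fconstexpr-depth=1`, because tail calls no longer consume evaluation depth. Without the attribute, the recursion crashes at around depth 1600.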
This change makes `[[clang::musttail]]` work for constant evaluation. Function calls marked with this attribute no longer use the system stack. Instead, they are performed in a loop after the nearest enclosing function call returns. The attribute is already very strict and checks all problematic cases (non-trivial destructors, references to local variables). Since the attribute already exists, this is purely a performance improvement.
This code's execution is now shallow: `VisitCallExpr` only prepares the call and its arguments, and the call itself runs after the current function's call has finished. This PR is work in progress.