Skip to content

Conversation

@hanickadot
Copy link
Contributor

@hanickadot hanickadot commented May 4, 2025

This change makes [[clang::musttail]] work during constant evaluation. Function calls marked with this attribute no longer consume the system stack; instead, the marked call is deferred and executed in a loop after the nearest enclosing (non-tail) function call returns. The attribute is already very strict and checks all problematic cases (non-trivial destructors, referencing local variables).

The attribute already exists; this is just a performance improvement.

// Example: `outer` tail-calls `inner` via [[clang::musttail]]. With this
// change, the constant evaluator defers that call and runs it after `outer`'s
// own call finishes, instead of nesting it on the native stack.
constexpr int inner(int v) {
	return v + 3;
}

constexpr int outer(int v) {
	// The attribute is strict: if this tail call were not valid, clang would
	// reject the code at compile time.
	[[clang::musttail]] return inner(v);
}

// Evaluates to 45 (42 + 3).
constexpr int v = outer(42);

Evaluation of this code is now shallow: for the tail call, VisitCallExpr only prepares the call and its arguments, and the call itself is executed after the current function's call has finished.

This PR is work in progress.

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" labels May 4, 2025
@llvmbot
Copy link
Member

llvmbot commented May 4, 2025

@llvm/pr-subscribers-clang

Author: Hana Dusíková (hanickadot)

Changes

This change makes [[clang::musttail]] work. Function calls marked with this attribute won't use the system stack, but will loop after the nearest function call. The attribute is already very strict, and checks all problematic cases (non-trivial destructors, referencing local variables).

This PR is work in progress.


Full diff: https://github.com/llvm/llvm-project/pull/138477.diff

1 Files Affected:

  • (modified) clang/lib/AST/ExprConstant.cpp (+219-41)
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index b79d8c197fe7d..9ef6b983d196a 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -735,6 +735,13 @@ namespace {
             ScopeKind Scope)
         : Value(Val, Scope), Base(Base), T(T) {}
 
+    Cleanup(Cleanup &&Other) noexcept
+        : Value{Other.Value}, Base{Other.Base}, T{Other.T} {
+      Other.Value = {};
+    }
+
+    Cleanup &operator=(Cleanup &&) = default;
+
     /// Determine whether this cleanup should be performed at the end of the
     /// given kind of scope.
     bool isDestroyedAtEndOf(ScopeKind K) const {
@@ -1006,6 +1013,24 @@ namespace {
       EM_IgnoreSideEffects,
     } EvalMode;
 
+    /// Pointer to last tail recursion enabled return. Enforced with
+    /// [[clang::musttail]]
+    const ReturnStmt *TailRecursionReturnStmt = nullptr;
+
+    struct DeferRecursionFunctionCall {
+      const CallExpr *E{nullptr};
+      const FunctionDecl *Definition{nullptr};
+      bool HasThis{false};
+      APValue ThisVal{}; // can't use LValue here :(
+      llvm::ArrayRef<const clang::Expr *> Args{};
+      CallRef Call{};
+      Stmt *Body{nullptr};
+      SmallVector<QualType, 4> CovariantAdjustmentPath{};
+      SmallVector<Cleanup, 16> ArgumentsStored{};
+    };
+
+    DeferRecursionFunctionCall DeferFunctionCall{};
+
     /// Are we checking whether the expression is a potential constant
     /// expression?
     bool checkingPotentialConstantExpression() const override  {
@@ -1124,6 +1149,21 @@ namespace {
       return Result;
     }
 
+    void EnableTailRecursion(const ReturnStmt *ret) {
+      TailRecursionReturnStmt = ret;
+    }
+
+    void DisableTailRecursion() { TailRecursionReturnStmt = nullptr; }
+
+    bool TailRecursionReady() const { return DeferFunctionCall.E != nullptr; }
+
+    bool IsTailRecursion(const ReturnStmt *ret) {
+      if (TailRecursionReturnStmt != ret)
+        return false;
+      TailRecursionReturnStmt = nullptr;
+      return true;
+    }
+
     /// Get the allocated storage for the given parameter of the given call.
     APValue *getParamSlot(CallRef Call, const ParmVarDecl *PVD) {
       CallStackFrame *Frame = getCallFrameAndDepth(Call.CallIndex).first;
@@ -1439,6 +1479,12 @@ namespace {
       // instances of this class.
       Info.CurrentCall->popTempVersion();
     }
+
+    friend void transferFromCallScope(ScopeRAII &,
+                                      llvm::SmallVectorImpl<Cleanup> &);
+    friend bool transferIntoCallScope(ScopeRAII &,
+                                      llvm::SmallVectorImpl<Cleanup> &);
+
   private:
     static bool cleanup(EvalInfo &Info, bool RunDestructors,
                         unsigned OldStackSize) {
@@ -1457,6 +1503,10 @@ namespace {
         }
       }
 
+      compact(Info, OldStackSize);
+      return Success;
+    }
+    static void compact(EvalInfo &Info, unsigned OldStackSize) {
       // Compact any retained cleanups.
       auto NewEnd = Info.CleanupStack.begin() + OldStackSize;
       if (Kind != ScopeKind::Block)
@@ -1465,12 +1515,47 @@ namespace {
               return C.isDestroyedAtEndOf(Kind);
             });
       Info.CleanupStack.erase(NewEnd, Info.CleanupStack.end());
-      return Success;
     }
   };
   typedef ScopeRAII<ScopeKind::Block> BlockScopeRAII;
   typedef ScopeRAII<ScopeKind::FullExpression> FullExpressionRAII;
   typedef ScopeRAII<ScopeKind::Call> CallScopeRAII;
+
+  static void transferFromCallScope(CallScopeRAII &Scope,
+                                    llvm::SmallVectorImpl<Cleanup> &Backup) {
+    Backup.clear();
+
+    auto CurrentVariables = MutableArrayRef<Cleanup>(Scope.Info.CleanupStack)
+                                .slice(Scope.OldStackSize);
+
+    // Transfer the cleanup information of the tail call's arguments out of the
+    // current scope. These variables would otherwise be destroyed when this
+    // scope ends, but this scope only prepares the tail call; it does not run it.
+    Backup.clear();
+
+    for (Cleanup &Lifetime : CurrentVariables) {
+      Backup.push_back(std::move(Lifetime));
+    }
+
+    // Remove lifetime management from this scope.
+    Scope.compact(Scope.Info, Scope.OldStackSize);
+    Scope.Info.CleanupStack.truncate(
+        Scope.OldStackSize); // make sure this is ok
+    assert(Scope.Info.CleanupStack.size() == Scope.OldStackSize);
+  }
+
+  static bool transferIntoCallScope(CallScopeRAII &Scope,
+                                    llvm::SmallVectorImpl<Cleanup> &Backup) {
+    if (!Scope.cleanup(Scope.Info, true, Scope.OldStackSize))
+      return false;
+
+    for (auto &Lifetime : Backup) {
+      Scope.Info.CleanupStack.push_back(std::move(Lifetime));
+    }
+
+    Backup.clear();
+    return true;
+  }
 }
 
 bool SubobjectDesignator::checkSubobject(EvalInfo &Info, const Expr *E,
@@ -5614,10 +5699,14 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, EvalInfo &Info,
       // We know we returned, but we don't know what the value is.
       return ESR_Failed;
     }
-    if (RetExpr &&
-        !(Result.Slot
-              ? EvaluateInPlace(Result.Value, Info, *Result.Slot, RetExpr)
-              : Evaluate(Result.Value, Info, RetExpr)))
+
+    if (!RetExpr || !isa<CallExpr>(RetExpr)) {
+      Info.DisableTailRecursion();
+    }
+
+    if (RetExpr && !(Result.Slot ? EvaluateInPlace(Result.Value, Info,
+                                                   *Result.Slot, RetExpr)
+                                 : Evaluate(Result.Value, Info, RetExpr)))
       return ESR_Failed;
     return Scope.destroy() ? ESR_Returned : ESR_Failed;
   }
@@ -5869,32 +5958,37 @@ static EvalStmtResult EvaluateStmt(StmtResult &Result, EvalInfo &Info,
   case Stmt::AttributedStmtClass: {
     const auto *AS = cast<AttributedStmt>(S);
     const auto *SS = AS->getSubStmt();
+    const auto *RS = dyn_cast<ReturnStmt>(SS);
     MSConstexprContextRAII ConstexprContext(
-        *Info.CurrentCall, hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) &&
-                               isa<ReturnStmt>(SS));
+        *Info.CurrentCall,
+        hasSpecificAttr<MSConstexprAttr>(AS->getAttrs()) && RS != nullptr);
 
     auto LO = Info.getASTContext().getLangOpts();
-    if (LO.CXXAssumptions && !LO.MSVCCompat) {
-      for (auto *Attr : AS->getAttrs()) {
-        auto *AA = dyn_cast<CXXAssumeAttr>(Attr);
-        if (!AA)
-          continue;
-
-        auto *Assumption = AA->getAssumption();
-        if (Assumption->isValueDependent())
-          return ESR_Failed;
+    for (auto *Attr : AS->getAttrs()) {
+      if (auto *AA = dyn_cast<CXXAssumeAttr>(Attr)) {
+        // This branch handles C++'s [[assume(<EXPR>)]]
+        if (LO.CXXAssumptions && !LO.MSVCCompat) {
+          auto *Assumption = AA->getAssumption();
+          if (Assumption->isValueDependent())
+            return ESR_Failed;
 
-        if (Assumption->HasSideEffects(Info.getASTContext()))
-          continue;
+          if (Assumption->HasSideEffects(Info.getASTContext()))
+            continue;
 
-        bool Value;
-        if (!EvaluateAsBooleanCondition(Assumption, Value, Info))
-          return ESR_Failed;
-        if (!Value) {
-          Info.CCEDiag(Assumption->getExprLoc(),
-                       diag::note_constexpr_assumption_failed);
-          return ESR_Failed;
+          bool Value;
+          if (!EvaluateAsBooleanCondition(Assumption, Value, Info))
+            return ESR_Failed;
+          if (!Value) {
+            Info.CCEDiag(Assumption->getExprLoc(),
+                         diag::note_constexpr_assumption_failed);
+            return ESR_Failed;
+          }
         }
+      } else if (isa<MustTailAttr>(Attr) && RS != nullptr) {
+        // This branch handles [[clang::musttail]] enforcement on
+        // tail-recursion which is strict and already checked, otherwise it will
+        // fail to compile.
+        Info.EnableTailRecursion(RS);
       }
     }
 
@@ -6514,16 +6608,16 @@ static bool MaybeHandleUnionActiveMemberChange(EvalInfo &Info,
 
 static bool EvaluateCallArg(const ParmVarDecl *PVD, const Expr *Arg,
                             CallRef Call, EvalInfo &Info,
-                            bool NonNull = false) {
+                            CallStackFrame &CallerFrame, bool NonNull = false) {
   LValue LV;
   // Create the parameter slot and register its destruction. For a vararg
   // argument, create a temporary.
   // FIXME: For calling conventions that destroy parameters in the callee,
   // should we consider performing destruction when the function returns
   // instead?
-  APValue &V = PVD ? Info.CurrentCall->createParam(Call, PVD, LV)
-                   : Info.CurrentCall->createTemporary(Arg, Arg->getType(),
-                                                       ScopeKind::Call, LV);
+  APValue &V = PVD ? CallerFrame.createParam(Call, PVD, LV)
+                   : CallerFrame.createTemporary(Arg, Arg->getType(),
+                                                 ScopeKind::Call, LV);
   if (!EvaluateInPlace(V, Info, LV, Arg))
     return false;
 
@@ -6539,8 +6633,8 @@ static bool EvaluateCallArg(const ParmVarDecl *PVD, const Expr *Arg,
 
 /// Evaluate the arguments to a function call.
 static bool EvaluateArgs(ArrayRef<const Expr *> Args, CallRef Call,
-                         EvalInfo &Info, const FunctionDecl *Callee,
-                         bool RightToLeft = false) {
+                         EvalInfo &Info, CallStackFrame &CallerFrame,
+                         const FunctionDecl *Callee, bool RightToLeft = false) {
   bool Success = true;
   llvm::SmallBitVector ForbiddenNullArgs;
   if (Callee->hasAttr<NonNullAttr>()) {
@@ -6563,7 +6657,7 @@ static bool EvaluateArgs(ArrayRef<const Expr *> Args, CallRef Call,
     const ParmVarDecl *PVD =
         Idx < Callee->getNumParams() ? Callee->getParamDecl(Idx) : nullptr;
     bool NonNull = !ForbiddenNullArgs.empty() && ForbiddenNullArgs[Idx];
-    if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, NonNull)) {
+    if (!EvaluateCallArg(PVD, Args[Idx], Call, Info, CallerFrame, NonNull)) {
       // If we're checking for a potential constant expression, evaluate all
       // initializers even if some of them fail.
       if (!Info.noteFailure())
@@ -6650,6 +6744,44 @@ static bool HandleFunctionCall(SourceLocation CallLoc,
   return ESR == ESR_Returned;
 }
 
+static void HandleTailCallTransfer(
+    EvalInfo &Info, const CallExpr *E, const FunctionDecl *Definition,
+    const LValue *This, LValue &ThisVal,
+    llvm::ArrayRef<const clang::Expr *> Args, CallRef Call, Stmt *Body,
+    SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) {
+  auto &defer = Info.DeferFunctionCall;
+
+  defer.E = E;
+  defer.Definition = Definition;
+  defer.HasThis = This != nullptr;
+  ThisVal.moveInto(defer.ThisVal);
+  defer.Args = Args;
+  defer.Call = Call;
+  defer.Body = Body;
+  defer.CovariantAdjustmentPath = std::move(CovariantAdjustmentPath);
+
+  transferFromCallScope(Scope, defer.ArgumentsStored);
+}
+
+static bool HandleTailCallSetup(
+    EvalInfo &Info, const CallExpr *&E, const FunctionDecl *&Definition,
+    LValue *&This, LValue &ThisVal, llvm::ArrayRef<const clang::Expr *> &Args,
+    CallRef &Call, Stmt *&Body,
+    SmallVector<QualType, 4> &CovariantAdjustmentPath, CallScopeRAII &Scope) {
+  auto &defer = Info.DeferFunctionCall;
+  assert(defer.E != nullptr);
+
+  E = std::exchange(defer.E, nullptr);
+  Definition = defer.Definition;
+  ThisVal.setFrom(Info.Ctx, defer.ThisVal);
+  This = defer.HasThis ? &ThisVal : nullptr;
+  Args = defer.Args;
+  Call = defer.Call;
+  Body = defer.Body;
+  CovariantAdjustmentPath = std::move(defer.CovariantAdjustmentPath);
+  return transferIntoCallScope(Scope, defer.ArgumentsStored);
+}
+
 /// Evaluate a constructor call.
 static bool HandleConstructorCall(const Expr *E, const LValue &This,
                                   CallRef Call,
@@ -6871,7 +7003,7 @@ static bool HandleConstructorCall(const Expr *E, const LValue &This,
                                   EvalInfo &Info, APValue &Result) {
   CallScopeRAII CallScope(Info);
   CallRef Call = Info.CurrentCall->createCall(Definition);
-  if (!EvaluateArgs(Args, Call, Info, Definition))
+  if (!EvaluateArgs(Args, Call, Info, *Info.CurrentCall, Definition))
     return false;
 
   return HandleConstructorCall(E, This, Call, Definition, Info, Result) &&
@@ -8242,6 +8374,13 @@ class ExprEvaluatorBase
     APValue Result;
     if (!handleCallExpr(E, Result, nullptr))
       return false;
+
+    // When the current call has been deferred as a tail call, we don't have
+    // a result to report yet.
+    if (Info.DeferFunctionCall.E != nullptr) {
+      return true;
+    }
+
     return DerivedSuccess(Result, E);
   }
 
@@ -8257,6 +8396,11 @@ class ExprEvaluatorBase
     auto Args = llvm::ArrayRef(E->getArgs(), E->getNumArgs());
     bool HasQualifier = false;
 
+    // Check for a tail call before we start evaluating any subexpression,
+    // which could otherwise claim the tail-call marker for itself.
+    const bool TailRecursion =
+        std::exchange(Info.TailRecursionReturnStmt, nullptr) != nullptr;
+
     CallRef Call;
 
     // Extract function decl and 'this' pointer from the callee.
@@ -8317,12 +8461,15 @@ class ExprEvaluatorBase
       auto *OCE = dyn_cast<CXXOperatorCallExpr>(E);
       if (OCE && OCE->isAssignmentOp()) {
         assert(Args.size() == 2 && "wrong number of arguments in assignment");
-        Call = Info.CurrentCall->createCall(FD);
         bool HasThis = false;
         if (const auto *MD = dyn_cast<CXXMethodDecl>(FD))
           HasThis = MD->isImplicitObjectMemberFunction();
-        if (!EvaluateArgs(HasThis ? Args.slice(1) : Args, Call, Info, FD,
-                          /*RightToLeft=*/true))
+
+        CallStackFrame &CallOriginFrame =
+            *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall);
+        Call = CallOriginFrame.createCall(FD);
+        if (!EvaluateArgs(HasThis ? Args.slice(1) : Args, Call, Info,
+                          CallOriginFrame, FD, /*RightToLeft = */ true))
           return false;
       }
 
@@ -8404,8 +8551,10 @@ class ExprEvaluatorBase
 
     // Evaluate the arguments now if we've not already done so.
     if (!Call) {
-      Call = Info.CurrentCall->createCall(FD);
-      if (!EvaluateArgs(Args, Call, Info, FD))
+      CallStackFrame &CallOriginFrame =
+          *(TailRecursion ? Info.CurrentCall->Caller : Info.CurrentCall);
+      Call = CallOriginFrame.createCall(FD);
+      if (!EvaluateArgs(Args, Call, Info, CallOriginFrame, FD))
         return false;
     }
 
@@ -8438,11 +8587,40 @@ class ExprEvaluatorBase
     const FunctionDecl *Definition = nullptr;
     Stmt *Body = FD->getBody(Definition);
 
-    if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body) ||
-        !HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
+    if (!CheckConstexprFunction(Info, E->getExprLoc(), FD, Definition, Body)) {
+      return false;
+    }
+
+    // If we are doing a tail call, we need to store everything needed for
+    // the function call. At most one tail call is pending at any time during
+    // evaluation.
+    if (TailRecursion) {
+      HandleTailCallTransfer(Info, E, Definition, This, ThisVal, Args, Call,
+                             Body, CovariantAdjustmentPath, CallScope);
+      return true;
+    }
+
+    if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
                             Body, Info, Result, ResultSlot))
       return false;
 
+    // If a tail call is pending, we don't have a result yet.
+    assert(!Info.TailRecursionReady() || Result.isAbsent());
+
+    // A tail recursion can result in another tail recursion, so we need to loop
+    // here.
+    while (Info.TailRecursionReady()) {
+      if (!HandleTailCallSetup(Info, E, Definition, This, ThisVal, Args, Call,
+                               Body, CovariantAdjustmentPath, CallScope))
+        return false;
+
+      if (!HandleFunctionCall(E->getExprLoc(), Definition, This, E, Args, Call,
+                              Body, Info, Result, ResultSlot))
+        return false;
+    }
+
+    // TODO checkme this is correct
+    // We got out of tail recursion, it was just a normal function.
     if (!CovariantAdjustmentPath.empty() &&
         !HandleCovariantReturnAdjustment(Info, E, Result,
                                          CovariantAdjustmentPath))
@@ -17832,7 +18010,7 @@ bool Expr::EvaluateWithSubstitution(APValue &Value, ASTContext &Ctx,
       break;
     const ParmVarDecl *PVD = Callee->getParamDecl(Idx);
     if ((*I)->isValueDependent() ||
-        !EvaluateCallArg(PVD, *I, Call, Info) ||
+        !EvaluateCallArg(PVD, *I, Call, Info, *Info.CurrentCall) ||
         Info.EvalStatus.HasSideEffects) {
       // If evaluation fails, throw away the argument entirely.
       if (APValue *Slot = Info.getParamSlot(Call, PVD))

@hanickadot
Copy link
Contributor Author

hanickadot commented May 4, 2025

// Stress test: 200'000 nested calls during constant evaluation. Define
// MUST_TAIL as [[clang::musttail]] to exercise the deferred tail-call path,
// or define it as empty to get plain recursion for comparison.
constexpr int deep_test(int remaining) {
	if (remaining == 0) {
		return 42;
	}
	MUST_TAIL return deep_test(remaining - 1);
}

// Evaluates to 42 once the recursion bottoms out.
constexpr int result = deep_test(200'000);

With a debug build of clang:

clang++ -std=c++26 -c deep.cpp -fconstexpr-depth=1 -DMUST_TAIL="[[clang::musttail]]"

real	0m7.227s
user	0m7.199s
sys	0m0.019s
clang++ -std=c++26 -c deep.cpp -fconstexpr-depth=200000 -DMUST_TAIL=

Segmentation fault: 11

(the recursion crashes around depth 1600)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants