
Conversation

@folkertdev
Contributor

Basically #168506 but for RISC-V, so, to be clear, the hard work here is @heiher's. I figured we may as well get some extra eyeballs on this from the RISC-V side too.

Previously the RISC-V backend could not handle musttail calls with more arguments than fit in registers, nor any explicit byval or sret parameters/return values. Both cases are now implemented.
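
As a rough illustration, a call along the following lines combines stack-passed arguments with sret and byval and so previously could not be tail-called on RISC-V. This is a minimal sketch with made-up names (@caller, @callee, %struct.S), not taken from the patch; the added musttail.ll test below has the authoritative cases.

%struct.S = type { [4 x i32] }

declare void @callee(ptr sret(%struct.S), ptr byval(%struct.S) align 4,
                     i32, i32, i32, i32, i32, i32, i32)

define void @caller(ptr sret(%struct.S) %out, ptr byval(%struct.S) align 4 %in,
                    i32 %a, i32 %b, i32 %c, i32 %d, i32 %e, i32 %f, i32 %g) {
entry:
  ; musttail requires the caller and callee signatures (including sret/byval)
  ; to match, so the arguments are forwarded unchanged. With nine integer
  ; arguments, at least one ends up on the stack on both rv32 and rv64.
  musttail call void @callee(ptr sret(%struct.S) %out, ptr byval(%struct.S) align 4 %in,
                             i32 %a, i32 %b, i32 %c, i32 %d, i32 %e, i32 %f, i32 %g)
  ret void
}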

This is part of my push to get more LLVM backends to support byval and sret parameters so that Rust can stabilize guaranteed tail call support. See also:

@llvmbot
Member

llvmbot commented Dec 3, 2025

@llvm/pr-subscribers-llvm-selectiondag
@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-backend-risc-v

Author: Folkert de Vries (folkertdev)

Changes



Patch is 38.45 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170547.diff

6 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+103-40)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.h (+6)
  • (modified) llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h (+15)
  • (modified) llvm/test/CodeGen/RISCV/musttail-call.ll (+5-4)
  • (added) llvm/test/CodeGen/RISCV/musttail.ll (+571)
  • (modified) llvm/test/CodeGen/RISCV/tail-calls.ll (+30-46)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index be53f51afe79f..5d0a0664d9c14 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -23420,6 +23420,7 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
     SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {
 
   MachineFunction &MF = DAG.getMachineFunction();
+  RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
 
   switch (CallConv) {
   default:
@@ -23548,6 +23549,8 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
       continue;
     }
     InVals.push_back(ArgValue);
+    if (Ins[InsIdx].Flags.isByVal())
+      RVFI->addIncomingByValArgs(ArgValue);
   }
 
   if (any_of(ArgLocs,
@@ -23560,7 +23563,6 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
     const TargetRegisterClass *RC = &RISCV::GPRRegClass;
     MachineFrameInfo &MFI = MF.getFrameInfo();
     MachineRegisterInfo &RegInfo = MF.getRegInfo();
-    RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
 
     // Size of the vararg save area. For now, the varargs save area is either
     // zero or large enough to hold a0-a7.
@@ -23608,6 +23610,8 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
     RVFI->setVarArgsSaveSize(VarArgsSaveSize);
   }
 
+  RVFI->setArgumentStackSize(CCInfo.getStackSize());
+
   // All stores are grouped in one node to allow the matching between
   // the size of Ins and InVals. This only happens for vararg functions.
   if (!OutChains.empty()) {
@@ -23629,6 +23633,7 @@ bool RISCVTargetLowering::isEligibleForTailCallOptimization(
   auto &Outs = CLI.Outs;
   auto &Caller = MF.getFunction();
   auto CallerCC = Caller.getCallingConv();
+  auto *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
 
   // Exception-handling functions need a special set of instructions to
   // indicate a return to the hardware. Tail-calling another function would
@@ -23638,29 +23643,28 @@ bool RISCVTargetLowering::isEligibleForTailCallOptimization(
   if (Caller.hasFnAttribute("interrupt"))
     return false;
 
-  // Do not tail call opt if the stack is used to pass parameters.
-  if (CCInfo.getStackSize() != 0)
+  // If the stack arguments for this call do not fit into our own save area then
+  // the call cannot be made tail.
+  if (CCInfo.getStackSize() > RVFI->getArgumentStackSize())
     return false;
 
-  // Do not tail call opt if any parameters need to be passed indirectly.
-  // Since long doubles (fp128) and i128 are larger than 2*XLEN, they are
-  // passed indirectly. So the address of the value will be passed in a
-  // register, or if not available, then the address is put on the stack. In
-  // order to pass indirectly, space on the stack often needs to be allocated
-  // in order to store the value. In this case the CCInfo.getNextStackOffset()
-  // != 0 check is not enough and we need to check if any CCValAssign ArgsLocs
-  // are passed CCValAssign::Indirect.
-  for (auto &VA : ArgLocs)
-    if (VA.getLocInfo() == CCValAssign::Indirect)
-      return false;
-
   // Do not tail call opt if either caller or callee uses struct return
   // semantics.
   auto IsCallerStructRet = Caller.hasStructRetAttr();
   auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
-  if (IsCallerStructRet || IsCalleeStructRet)
+  if (IsCallerStructRet != IsCalleeStructRet)
     return false;
 
+  // Do not tail call opt if caller's and callee's byval arguments do not match.
+  for (unsigned i = 0, j = 0; i < Outs.size(); i++) {
+    if (!Outs[i].Flags.isByVal())
+      continue;
+    if (j++ >= RVFI->getIncomingByValArgsSize())
+      return false;
+    if (RVFI->getIncomingByValArgs(i).getValueType() != Outs[i].ArgVT)
+      return false;
+  }
+
   // The callee has to preserve all registers the caller needs to preserve.
   const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
   const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
@@ -23670,16 +23674,47 @@ bool RISCVTargetLowering::isEligibleForTailCallOptimization(
       return false;
   }
 
-  // Byval parameters hand the function a pointer directly into the stack area
-  // we want to reuse during a tail call. Working around this *is* possible
-  // but less efficient and uglier in LowerCall.
-  for (auto &Arg : Outs)
-    if (Arg.Flags.isByVal())
-      return false;
+  // If the callee takes no arguments then go on to check the results of the
+  // call.
+  const MachineRegisterInfo &MRI = MF.getRegInfo();
+  const SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
+  if (!parametersInCSRMatch(MRI, CallerPreserved, ArgLocs, OutVals))
+    return false;
 
   return true;
 }
 
+SDValue RISCVTargetLowering::addTokenForArgument(SDValue Chain,
+                                                 SelectionDAG &DAG,
+                                                 MachineFrameInfo &MFI,
+                                                 int ClobberedFI) const {
+  SmallVector<SDValue, 8> ArgChains;
+  int64_t FirstByte = MFI.getObjectOffset(ClobberedFI);
+  int64_t LastByte = FirstByte + MFI.getObjectSize(ClobberedFI) - 1;
+
+  // Include the original chain at the beginning of the list. When this is
+  // used by target LowerCall hooks, this helps legalize find the
+  // CALLSEQ_BEGIN node.
+  ArgChains.push_back(Chain);
+
+  // Add a chain value for each stack argument corresponding
+  for (SDNode *U : DAG.getEntryNode().getNode()->users())
+    if (LoadSDNode *L = dyn_cast<LoadSDNode>(U))
+      if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(L->getBasePtr()))
+        if (FI->getIndex() < 0) {
+          int64_t InFirstByte = MFI.getObjectOffset(FI->getIndex());
+          int64_t InLastByte = InFirstByte;
+          InLastByte += MFI.getObjectSize(FI->getIndex()) - 1;
+
+          if ((InFirstByte <= FirstByte && FirstByte <= InLastByte) ||
+              (FirstByte <= InFirstByte && InFirstByte <= LastByte))
+            ArgChains.push_back(SDValue(L, 1));
+        }
+
+  // Build a tokenfactor for all the chains.
+  return DAG.getNode(ISD::TokenFactor, SDLoc(Chain), MVT::Other, ArgChains);
+}
+
 static Align getPrefTypeAlign(EVT VT, SelectionDAG &DAG) {
   return DAG.getDataLayout().getPrefTypeAlign(
       VT.getTypeForEVT(*DAG.getContext()));
@@ -23704,6 +23739,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
   const CallBase *CB = CLI.CB;
 
   MachineFunction &MF = DAG.getMachineFunction();
+  RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
   MachineFunction::CallSiteInfo CSInfo;
 
   // Set type id for call site info.
@@ -23738,7 +23774,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
 
   // Create local copies for byval args
   SmallVector<SDValue, 8> ByValArgs;
-  for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
+  for (unsigned i = 0, j = 0, e = Outs.size(); i != e; ++i) {
     ISD::ArgFlagsTy Flags = Outs[i].Flags;
     if (!Flags.isByVal())
       continue;
@@ -23747,16 +23783,27 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
     unsigned Size = Flags.getByValSize();
     Align Alignment = Flags.getNonZeroByValAlign();
 
-    int FI =
-        MF.getFrameInfo().CreateStackObject(Size, Alignment, /*isSS=*/false);
-    SDValue FIPtr = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
     SDValue SizeNode = DAG.getConstant(Size, DL, XLenVT);
+    SDValue Dst;
 
-    Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Alignment,
-                          /*IsVolatile=*/false,
-                          /*AlwaysInline=*/false, /*CI*/ nullptr, IsTailCall,
-                          MachinePointerInfo(), MachinePointerInfo());
-    ByValArgs.push_back(FIPtr);
+    if (IsTailCall) {
+      SDValue CallerArg = RVFI->getIncomingByValArgs(j++);
+      if (isa<GlobalAddressSDNode>(Arg) || isa<ExternalSymbolSDNode>(Arg) ||
+          isa<FrameIndexSDNode>(Arg))
+        Dst = CallerArg;
+    } else {
+      int FI =
+          MF.getFrameInfo().CreateStackObject(Size, Alignment, /*isSS=*/false);
+      Dst = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
+    }
+    if (Dst) {
+      Chain =
+          DAG.getMemcpy(Chain, DL, Dst, Arg, SizeNode, Alignment,
+                        /*IsVolatile=*/false,
+                        /*AlwaysInline=*/false, /*CI=*/nullptr, std::nullopt,
+                        MachinePointerInfo(), MachinePointerInfo());
+      ByValArgs.push_back(Dst);
+    }
   }
 
   if (!IsTailCall)
@@ -23859,8 +23906,12 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
     }
 
     // Use local copy if it is a byval arg.
-    if (Flags.isByVal())
-      ArgValue = ByValArgs[j++];
+    if (Flags.isByVal()) {
+      if (!IsTailCall || (isa<GlobalAddressSDNode>(ArgValue) ||
+                          isa<ExternalSymbolSDNode>(ArgValue) ||
+                          isa<FrameIndexSDNode>(ArgValue)))
+        ArgValue = ByValArgs[j++];
+    }
 
     if (VA.isRegLoc()) {
       // Queue up the argument copies and emit them at the end.
@@ -23871,20 +23922,32 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
         CSInfo.ArgRegPairs.emplace_back(VA.getLocReg(), i);
     } else {
       assert(VA.isMemLoc() && "Argument not register or memory");
-      assert(!IsTailCall && "Tail call not allowed if stack is used "
-                            "for passing parameters");
+      SDValue DstAddr;
+      MachinePointerInfo DstInfo;
+      int32_t Offset = VA.getLocMemOffset();
 
       // Work out the address of the stack slot.
       if (!StackPtr.getNode())
         StackPtr = DAG.getCopyFromReg(Chain, DL, RISCV::X2, PtrVT);
-      SDValue Address =
-          DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr,
-                      DAG.getIntPtrConstant(VA.getLocMemOffset(), DL));
+
+      if (IsTailCall) {
+        unsigned OpSize = (VA.getValVT().getSizeInBits() + 7) / 8;
+        int FI = MF.getFrameInfo().CreateFixedObject(OpSize, Offset, true);
+        DstAddr = DAG.getFrameIndex(FI, PtrVT);
+        DstInfo = MachinePointerInfo::getFixedStack(MF, FI);
+        // Make sure any stack arguments overlapping with where we're storing
+        // are loaded before this eventual operation. Otherwise they'll be
+        // clobbered.
+        Chain = addTokenForArgument(Chain, DAG, MF.getFrameInfo(), FI);
+      } else {
+        SDValue PtrOff = DAG.getIntPtrConstant(Offset, DL);
+        DstAddr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, PtrOff);
+        DstInfo = MachinePointerInfo::getStack(MF, Offset);
+      }
 
       // Emit the store.
       MemOpChains.push_back(
-          DAG.getStore(Chain, DL, ArgValue, Address,
-                       MachinePointerInfo::getStack(MF, VA.getLocMemOffset())));
+          DAG.getStore(Chain, DL, ArgValue, DstAddr, DstInfo));
     }
   }
 
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
index 9b46936f195e6..0852c512c3e8f 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -74,6 +74,12 @@ class RISCVTargetLowering : public TargetLowering {
   /// Customize the preferred legalization strategy for certain types.
   LegalizeTypeAction getPreferredVectorAction(MVT VT) const override;
 
+  /// Finds the incoming stack arguments which overlap the given fixed stack
+  /// object and incorporates their load into the current chain. This prevents
+  /// an upcoming store from clobbering the stack argument before it's used.
+  SDValue addTokenForArgument(SDValue Chain, SelectionDAG &DAG,
+                              MachineFrameInfo &MFI, int ClobberedFI) const;
+
   bool softPromoteHalfType() const override { return true; }
 
   /// Return the register type for a given MVT, ensuring vectors are treated
diff --git a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
index f9be80feae211..9c2cd708f2784 100644
--- a/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
+++ b/llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
@@ -65,6 +65,14 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
   uint64_t RVVPadding = 0;
   /// Size of stack frame to save callee saved registers
   unsigned CalleeSavedStackSize = 0;
+
+  /// ArgumentStackSize - amount of bytes on stack consumed by the arguments
+  /// being passed on the stack
+  unsigned ArgumentStackSize = 0;
+
+  /// Incoming ByVal arguments
+  SmallVector<SDValue, 8> IncomingByValArgs;
+
   /// Is there any vector argument or return?
   bool IsVectorCall = false;
 
@@ -142,6 +150,13 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
   unsigned getCalleeSavedStackSize() const { return CalleeSavedStackSize; }
   void setCalleeSavedStackSize(unsigned Size) { CalleeSavedStackSize = Size; }
 
+  unsigned getArgumentStackSize() const { return ArgumentStackSize; }
+  void setArgumentStackSize(unsigned size) { ArgumentStackSize = size; }
+
+  void addIncomingByValArgs(SDValue Val) { IncomingByValArgs.push_back(Val); }
+  SDValue &getIncomingByValArgs(int Idx) { return IncomingByValArgs[Idx]; }
+  unsigned getIncomingByValArgsSize() { return IncomingByValArgs.size(); }
+
   enum class PushPopKind { None = 0, StdExtZcmp, VendorXqccmp };
 
   PushPopKind getPushPopKind(const MachineFunction &MF) const;
diff --git a/llvm/test/CodeGen/RISCV/musttail-call.ll b/llvm/test/CodeGen/RISCV/musttail-call.ll
index f6ec5307b8bad..a3ac3560378db 100644
--- a/llvm/test/CodeGen/RISCV/musttail-call.ll
+++ b/llvm/test/CodeGen/RISCV/musttail-call.ll
@@ -9,12 +9,13 @@
 ; RUN: not --crash llc -mtriple riscv64-unknown-elf -o - %s \
 ; RUN: 2>&1 | FileCheck %s
 
-%struct.A = type { i32 }
+declare void @callee_musttail()
 
-declare void @callee_musttail(ptr sret(%struct.A) %a)
-define void @caller_musttail(ptr sret(%struct.A) %a) {
+define void @caller_musttail() #0 {
 ; CHECK: LLVM ERROR: failed to perform tail call elimination on a call site marked musttail
 entry:
-  musttail call void @callee_musttail(ptr sret(%struct.A) %a)
+  musttail call void @callee_musttail()
   ret void
 }
+
+attributes #0 = { "interrupt"="machine" }
diff --git a/llvm/test/CodeGen/RISCV/musttail.ll b/llvm/test/CodeGen/RISCV/musttail.ll
new file mode 100644
index 0000000000000..4765fe7a4f233
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/musttail.ll
@@ -0,0 +1,571 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=riscv32 %s -o - | FileCheck %s --check-prefix=RV32
+; RUN: llc -mtriple=riscv64 %s -o - | FileCheck %s --check-prefix=RV64
+
+declare i32 @many_args_callee(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, i32 %9)
+
+define i32 @many_args_tail(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, i32 %9) {
+; RV32-LABEL: many_args_tail:
+; RV32:       # %bb.0:
+; RV32-NEXT:    li a0, 9
+; RV32-NEXT:    li t0, 8
+; RV32-NEXT:    li a1, 1
+; RV32-NEXT:    li a2, 2
+; RV32-NEXT:    li a3, 3
+; RV32-NEXT:    li a4, 4
+; RV32-NEXT:    li a5, 5
+; RV32-NEXT:    li a6, 6
+; RV32-NEXT:    sw a0, 4(sp)
+; RV32-NEXT:    li a7, 7
+; RV32-NEXT:    sw t0, 0(sp)
+; RV32-NEXT:    li a0, 0
+; RV32-NEXT:    tail many_args_callee
+;
+; RV64-LABEL: many_args_tail:
+; RV64:       # %bb.0:
+; RV64-NEXT:    li a0, 9
+; RV64-NEXT:    li t0, 8
+; RV64-NEXT:    li a1, 1
+; RV64-NEXT:    li a2, 2
+; RV64-NEXT:    li a3, 3
+; RV64-NEXT:    li a4, 4
+; RV64-NEXT:    li a5, 5
+; RV64-NEXT:    li a6, 6
+; RV64-NEXT:    sd a0, 8(sp)
+; RV64-NEXT:    li a7, 7
+; RV64-NEXT:    sd t0, 0(sp)
+; RV64-NEXT:    li a0, 0
+; RV64-NEXT:    tail many_args_callee
+  %ret = tail call i32 @many_args_callee(i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9)
+  ret i32 %ret
+}
+
+define i32 @many_args_musttail(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, i32 %9) {
+; RV32-LABEL: many_args_musttail:
+; RV32:       # %bb.0:
+; RV32-NEXT:    li a0, 9
+; RV32-NEXT:    li t0, 8
+; RV32-NEXT:    li a1, 1
+; RV32-NEXT:    li a2, 2
+; RV32-NEXT:    li a3, 3
+; RV32-NEXT:    li a4, 4
+; RV32-NEXT:    li a5, 5
+; RV32-NEXT:    li a6, 6
+; RV32-NEXT:    sw a0, 4(sp)
+; RV32-NEXT:    li a7, 7
+; RV32-NEXT:    sw t0, 0(sp)
+; RV32-NEXT:    li a0, 0
+; RV32-NEXT:    tail many_args_callee
+;
+; RV64-LABEL: many_args_musttail:
+; RV64:       # %bb.0:
+; RV64-NEXT:    li a0, 9
+; RV64-NEXT:    li t0, 8
+; RV64-NEXT:    li a1, 1
+; RV64-NEXT:    li a2, 2
+; RV64-NEXT:    li a3, 3
+; RV64-NEXT:    li a4, 4
+; RV64-NEXT:    li a5, 5
+; RV64-NEXT:    li a6, 6
+; RV64-NEXT:    sd a0, 8(sp)
+; RV64-NEXT:    li a7, 7
+; RV64-NEXT:    sd t0, 0(sp)
+; RV64-NEXT:    li a0, 0
+; RV64-NEXT:    tail many_args_callee
+  %ret = musttail call i32 @many_args_callee(i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9)
+  ret i32 %ret
+}
+
+; This function has more arguments than its tail-callee. This isn't valid for
+; the musttail attribute, but can still be tail-called as a non-guaranteed
+; optimisation, because the outgoing arguments to @many_args_callee fit in the
+; stack space allocated by the caller of @more_args_tail.
+define i32 @more_args_tail(i32 %0, i32 %1, i32 %2, i32 %3, i32 %4, i32 %5, i32 %6, i32 %7, i32 %8, i32 %9) {
+; RV32-LABEL: more_args_tail:
+; RV32:       # %bb.0:
+; RV32-NEXT:    li a0, 9
+; RV32-NEXT:    li t0, 8
+; RV32-NEXT:    li a1, 1
+; RV32-NEXT:    li a2, 2
+; RV32-NEXT:    li a3, 3
+; RV32-NEXT:    li a4, 4
+; RV32-NEXT:    li a5, 5
+; RV32-NEXT:    li a6, 6
+; RV32-NEXT:    sw a0, 4(sp)
+; RV32-NEXT:    li a7, 7
+; RV32-NEXT:    sw t0, 0(sp)
+; RV32-NEXT:    li a0, 0
+; RV32-NEXT:    tail many_args_callee
+;
+; RV64-LABEL: more_args_tail:
+; RV64:       # %bb.0:
+; RV64-NEXT:    li a0, 9
+; RV64-NEXT:    li t0, 8
+; RV64-NEXT:    li a1, 1
+; RV64-NEXT:    li a2, 2
+; RV64-NEXT:    li a3, 3
+; RV64-NEXT:    li a4, 4
+; RV64-NEXT:    li a5, 5
+; RV64-NEXT:    li a6, 6
+; RV64-NEXT:    sd a0, 8(sp)
+; RV64-NEXT:    li a7, 7
+; RV64-NEXT:    sd t0, 0(sp)
+; RV64-NEXT:    li a0, 0
+; RV64-NEXT:    tail many_args_callee
+  %ret = tail call i32 @many_args_callee(i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9)
+  ret i32 %ret
+}
+
+; Again, this isn't valid for musttail, but can be tail-called in practice
+; because the stack size is the same.
+define i32 @different_args_tail_32bit(i64 %0, i64 %1, i64 %2, i64 %3, i64 %4) {
+; RV32-LABEL: different_args_tail_32bit:
+; RV32:       # %bb.0:
+; RV32-NEXT:    li a0, 9
+; RV32-NEXT:    li t0, 8
+; RV32-NEXT:    li a1, 1
+; RV32-NEXT:    li a2, 2
+; RV32-NEXT:    li a3, 3
+; RV32-NEXT:    li a4, 4
+; RV32-NEXT:    li a5, 5
+; RV32-NEXT:    li a6, 6
+; RV32-NEXT:    sw a0, 4(sp)
+; RV32-NEXT:    li a7, 7
+; RV32-NEXT:    sw t0, 0(sp)
+; RV32-NEXT:    li a0, 0
+; RV32-NEXT:    tail many_args_callee
+;
+; RV64-LABEL: different_args_tail_32bit:
+; RV64:       # %bb.0:
+; RV64-NEXT:    addi sp, sp, -32
+; RV64-NEXT:    .cfi_def_cfa_offset 32
+; RV64-NEXT:    sd ra, 24(sp) # 8-byte Folded Spill
+; RV64-NEXT:    .cfi_offset ra, -8
+; RV64-NEXT:    li a0, 9
+; RV64-NEXT:    li t0, 8
+; RV64-NEXT:    li a1, 1
+; RV64-NEXT:    li a2, 2
+; RV64-NEXT:    li a3, 3
+; RV64-NEXT:    li a4, 4
+; RV64-NEXT:    li a5, 5
+; RV64-NEXT:    li a6, 6
+; RV64-NEXT:    li a7, 7
+; RV64-NEXT:    sd t0, 0(sp)
+; RV64-NEXT:    sd a0, 8(sp)
+; RV64-NEXT:    li a0, 0
+; RV64-NEXT:    call many_args_callee
+; RV64-NEXT:    ld ra, 24(sp) # 8-byte Folded Reload
+; RV64-NEXT:    .cfi_restore ra
+; RV64-NEXT:    addi sp, sp, 32
+; RV64-NEXT:    .cfi_def_cfa_offset 0
+; RV64-NEXT:    ret
+  %ret = tail call i32 @many_args_callee(i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9)
+  ret i32 %ret
+}
+
+define i32 @different_args_tail_64bit(i128 %0, i128 %1, i128 %2, i128 %3, i128 %4) {
+; RV32-LABEL: different_args_tail_64bit:
+; RV32:       # %bb.0:
+; RV32-NEXT:    addi sp, sp, -16
+; RV32-NEXT:    .cfi_def_cfa_offset 16
+; RV32-NEXT:    sw ra, 12(sp) # 4-byte Folded Spill
+; RV32-NEXT:    .cfi_offset ra, -4
+; RV32-NEXT:    li a0, 9
+; RV32-NEXT:    li t0, 8
+; RV32-NEXT:    li a1, 1
+; RV32-NEXT:    li a2, 2
+; RV32-NEXT:    li a3, 3
+; RV32-NEXT:    li a4, 4
+; RV32-NEXT:    li a5, 5
+; RV32-NEXT:    li a6, 6
+; RV32-NEXT:    li a7, 7
+; RV32-NEXT:    sw t0, 0(sp)
+; RV32-NEXT...
[truncated]

Comment on lines 23673 to 23678
@@ -23678,29 +23643,28 @@ bool RISCVTargetLowering::isEligibleForTailCallOptimization(
if (Caller.hasFnAttribute("interrupt"))
Contributor Author


based on the updated tests it seems fine in terms of efficiency.

@folkertdev force-pushed the riscv-tail-call-based-on-loongarch branch from f7002a2 to ef13b0e on December 3, 2025, 21:32
ArgChains.push_back(Chain);

// Add a chain value for each stack argument corresponding
for (SDNode *U : DAG.getEntryNode().getNode()->users())
Contributor


This is making a big assumption about the chain layout, can't you collect this when actually emitting the argument operations in the first place?
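
For readers following along, the hazard this chain tracking guards against: the outgoing stack stores of a tail call overwrite the caller's own incoming argument area, so any outgoing value loaded from that area must be loaded before the overlapping store. A minimal, hypothetical IR sketch (it reuses the many_args_callee signature from the added test; @swap_last_two is made up and not part of the patch):

declare i32 @many_args_callee(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)

define i32 @swap_last_two(i32 %a, i32 %b, i32 %c, i32 %d, i32 %e, i32 %f,
                          i32 %g, i32 %h, i32 %i, i32 %j) {
  ; %i and %j arrive in the caller's incoming stack slots; the tail call below
  ; stores its ninth and tenth arguments into those same slots in swapped
  ; order, so both incoming values must be read before either slot is written.
  %r = musttail call i32 @many_args_callee(i32 %a, i32 %b, i32 %c, i32 %d,
                                           i32 %e, i32 %f, i32 %g, i32 %h,
                                           i32 %j, i32 %i)
  ret i32 %r
}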

Contributor Author


I'm for sure not the right person to answer that question.

This code is taken from the aarch64 backend originally, which fixed its tail calls in #109943. From what I've been able to gather, the other backends in play worked off of the aarch64 template.

Collaborator


This code is taken from the aarch64 backend originally, which fixed its tail calls in #109943.

That's ARM, not AArch64.

Contributor Author


Hmm, I got confused then. The addTokenForArgument function is from 11 years ago (and I guess has worked OK over that period):

09cc564

but I know that aarch64 support for tail calls with byval arguments improved in llvm 20 (https://godbolt.org/z/b5KbTPqzY), and I guess I got that mixed up with comments in the riscv tail call implementation saying it's based on the arm code.


Is there a way to still share code when following this suggestion? At least as I understand it, it needs some additional bookkeeping on e.g. RISCVMachineFunctionInfo when processing arguments in LowerFormalArguments?

Contributor


This is state local to the selection process, so preferably this should stay out of MachineFunctionInfo. Maybe FunctionLoweringInfo can track this?
