Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -4999,6 +4999,12 @@ class LLVM_ABI TargetLowering : public TargetLoweringBase {
llvm_unreachable("Not Implemented");
}

/// Finds the incoming stack arguments which overlap the given fixed stack
/// object and incorporates their load into the current chain. This prevents
/// an upcoming store from clobbering the stack argument before it's used.
SDValue addTokenForArgument(SDValue Chain, SelectionDAG &DAG,
MachineFrameInfo &MFI, int ClobberedFI) const;

/// Target-specific cleanup for formal ByVal parameters.
virtual void HandleByVal(CCState *, unsigned &, Align) const {}

Expand Down
30 changes: 30 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,36 @@ bool TargetLowering::parametersInCSRMatch(const MachineRegisterInfo &MRI,
return true;
}

SDValue TargetLowering::addTokenForArgument(SDValue Chain, SelectionDAG &DAG,
MachineFrameInfo &MFI,
int ClobberedFI) const {
SmallVector<SDValue, 8> ArgChains;
int64_t FirstByte = MFI.getObjectOffset(ClobberedFI);
int64_t LastByte = FirstByte + MFI.getObjectSize(ClobberedFI) - 1;

// Include the original chain at the beginning of the list. When this is
// used by target LowerCall hooks, this helps legalize find the
// CALLSEQ_BEGIN node.
ArgChains.push_back(Chain);

// Add a chain value for each stack argument corresponding
for (SDNode *U : DAG.getEntryNode().getNode()->users())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is making a big assumption about the chain layout, can't you collect this when actually emitting the argument operations in the first place?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm for sure not the right person to answer that question.

This code is taken from the aarch64 backend originally, which fixed its tail calls in #109943. From what I've been able to gather, the other backends in play worked off of the aarch64 template.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is taken from the aarch64 backend originally, which fixed its tail calls in #109943.

That's ARM not AArch64.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm I got confused then. The addTokenForArgument function is from 11 years ago (and I guess has worked OK over that period):

09cc564

but I know that aarch64 support for tail calls with byval arguments improved in llvm 20 (https://godbolt.org/z/b5KbTPqzY), and I guess I got that mixed up with comments in the riscv tail call implementation saying it's based on the arm code.


Is there a way to still share code when following this suggestion? At least as I understand it it needs some additional bookkeeping on e.g. RISCVMachineFunctionInfo when processing arguments in LowerFormalArguments?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is state local to the selection process, so preferably this should stay out of MachineFunctionInfo. Maybe FunctionLoweringInfo can track this?

if (LoadSDNode *L = dyn_cast<LoadSDNode>(U))
if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(L->getBasePtr()))
if (FI->getIndex() < 0) {
int64_t InFirstByte = MFI.getObjectOffset(FI->getIndex());
int64_t InLastByte = InFirstByte;
InLastByte += MFI.getObjectSize(FI->getIndex()) - 1;

if ((InFirstByte <= FirstByte && FirstByte <= InLastByte) ||
(FirstByte <= InFirstByte && InFirstByte <= LastByte))
ArgChains.push_back(SDValue(L, 1));
}

// Build a tokenfactor for all the chains.
return DAG.getNode(ISD::TokenFactor, SDLoc(Chain), MVT::Other, ArgChains);
}

/// Set CallLoweringInfo attribute flags based on a call instruction
/// and called function attributes.
void TargetLoweringBase::ArgListEntry::setAttributes(const CallBase *Call,
Expand Down
31 changes: 0 additions & 31 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9353,37 +9353,6 @@ bool AArch64TargetLowering::isEligibleForTailCallOptimization(
return true;
}

SDValue AArch64TargetLowering::addTokenForArgument(SDValue Chain,
SelectionDAG &DAG,
MachineFrameInfo &MFI,
int ClobberedFI) const {
SmallVector<SDValue, 8> ArgChains;
int64_t FirstByte = MFI.getObjectOffset(ClobberedFI);
int64_t LastByte = FirstByte + MFI.getObjectSize(ClobberedFI) - 1;

// Include the original chain at the beginning of the list. When this is
// used by target LowerCall hooks, this helps legalize find the
// CALLSEQ_BEGIN node.
ArgChains.push_back(Chain);

// Add a chain value for each stack argument corresponding
for (SDNode *U : DAG.getEntryNode().getNode()->users())
if (LoadSDNode *L = dyn_cast<LoadSDNode>(U))
if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(L->getBasePtr()))
if (FI->getIndex() < 0) {
int64_t InFirstByte = MFI.getObjectOffset(FI->getIndex());
int64_t InLastByte = InFirstByte;
InLastByte += MFI.getObjectSize(FI->getIndex()) - 1;

if ((InFirstByte <= FirstByte && FirstByte <= InLastByte) ||
(FirstByte <= InFirstByte && InFirstByte <= LastByte))
ArgChains.push_back(SDValue(L, 1));
}

// Build a tokenfactor for all the chains.
return DAG.getNode(ISD::TokenFactor, SDLoc(Chain), MVT::Other, ArgChains);
}

bool AArch64TargetLowering::DoesCalleeRestoreStack(CallingConv::ID CallCC,
bool TailCallOpt) const {
return (CallCC == CallingConv::Fast && TailCallOpt) ||
Expand Down
6 changes: 0 additions & 6 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,12 +630,6 @@ class AArch64TargetLowering : public TargetLowering {
bool
isEligibleForTailCallOptimization(const CallLoweringInfo &CLI) const;

/// Finds the incoming stack arguments which overlap the given fixed stack
/// object and incorporates their load into the current chain. This prevents
/// an upcoming store from clobbering the stack argument before it's used.
SDValue addTokenForArgument(SDValue Chain, SelectionDAG &DAG,
MachineFrameInfo &MFI, int ClobberedFI) const;

bool DoesCalleeRestoreStack(CallingConv::ID CallCC, bool TailCallOpt) const;

void saveVarArgRegisters(CCState &CCInfo, SelectionDAG &DAG, const SDLoc &DL,
Expand Down
34 changes: 0 additions & 34 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1355,40 +1355,6 @@ CCAssignFn *AMDGPUTargetLowering::CCAssignFnForReturn(CallingConv::ID CC,
return AMDGPUCallLowering::CCAssignFnForReturn(CC, IsVarArg);
}

SDValue AMDGPUTargetLowering::addTokenForArgument(SDValue Chain,
SelectionDAG &DAG,
MachineFrameInfo &MFI,
int ClobberedFI) const {
SmallVector<SDValue, 8> ArgChains;
int64_t FirstByte = MFI.getObjectOffset(ClobberedFI);
int64_t LastByte = FirstByte + MFI.getObjectSize(ClobberedFI) - 1;

// Include the original chain at the beginning of the list. When this is
// used by target LowerCall hooks, this helps legalize find the
// CALLSEQ_BEGIN node.
ArgChains.push_back(Chain);

// Add a chain value for each stack argument corresponding
for (SDNode *U : DAG.getEntryNode().getNode()->users()) {
if (LoadSDNode *L = dyn_cast<LoadSDNode>(U)) {
if (FrameIndexSDNode *FI = dyn_cast<FrameIndexSDNode>(L->getBasePtr())) {
if (FI->getIndex() < 0) {
int64_t InFirstByte = MFI.getObjectOffset(FI->getIndex());
int64_t InLastByte = InFirstByte;
InLastByte += MFI.getObjectSize(FI->getIndex()) - 1;

if ((InFirstByte <= FirstByte && FirstByte <= InLastByte) ||
(FirstByte <= InFirstByte && InFirstByte <= LastByte))
ArgChains.push_back(SDValue(L, 1));
}
}
}
}

// Build a tokenfactor for all the chains.
return DAG.getNode(ISD::TokenFactor, SDLoc(Chain), MVT::Other, ArgChains);
}

SDValue AMDGPUTargetLowering::lowerUnhandledCall(CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals,
StringRef Reason) const {
Expand Down
5 changes: 0 additions & 5 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,6 @@ class AMDGPUTargetLowering : public TargetLowering {
const SmallVectorImpl<SDValue> &OutVals, const SDLoc &DL,
SelectionDAG &DAG) const override;

SDValue addTokenForArgument(SDValue Chain,
SelectionDAG &DAG,
MachineFrameInfo &MFI,
int ClobberedFI) const;

SDValue lowerUnhandledCall(CallLoweringInfo &CLI,
SmallVectorImpl<SDValue> &InVals,
StringRef Reason) const;
Expand Down
112 changes: 72 additions & 40 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23420,6 +23420,7 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
SelectionDAG &DAG, SmallVectorImpl<SDValue> &InVals) const {

MachineFunction &MF = DAG.getMachineFunction();
RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();

switch (CallConv) {
default:
Expand Down Expand Up @@ -23548,6 +23549,8 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
continue;
}
InVals.push_back(ArgValue);
if (Ins[InsIdx].Flags.isByVal())
RVFI->addIncomingByValArgs(ArgValue);
}

if (any_of(ArgLocs,
Expand All @@ -23560,7 +23563,6 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
const TargetRegisterClass *RC = &RISCV::GPRRegClass;
MachineFrameInfo &MFI = MF.getFrameInfo();
MachineRegisterInfo &RegInfo = MF.getRegInfo();
RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();

// Size of the vararg save area. For now, the varargs save area is either
// zero or large enough to hold a0-a7.
Expand Down Expand Up @@ -23608,6 +23610,8 @@ SDValue RISCVTargetLowering::LowerFormalArguments(
RVFI->setVarArgsSaveSize(VarArgsSaveSize);
}

RVFI->setArgumentStackSize(CCInfo.getStackSize());

// All stores are grouped in one node to allow the matching between
// the size of Ins and InVals. This only happens for vararg functions.
if (!OutChains.empty()) {
Expand All @@ -23629,6 +23633,7 @@ bool RISCVTargetLowering::isEligibleForTailCallOptimization(
auto &Outs = CLI.Outs;
auto &Caller = MF.getFunction();
auto CallerCC = Caller.getCallingConv();
auto *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();

// Exception-handling functions need a special set of instructions to
// indicate a return to the hardware. Tail-calling another function would
Expand All @@ -23638,29 +23643,28 @@ bool RISCVTargetLowering::isEligibleForTailCallOptimization(
if (Caller.hasFnAttribute("interrupt"))
return false;

// Do not tail call opt if the stack is used to pass parameters.
if (CCInfo.getStackSize() != 0)
// If the stack arguments for this call do not fit into our own save area then
// the call cannot be made tail.
if (CCInfo.getStackSize() > RVFI->getArgumentStackSize())
return false;

// Do not tail call opt if any parameters need to be passed indirectly.
// Since long doubles (fp128) and i128 are larger than 2*XLEN, they are
// passed indirectly. So the address of the value will be passed in a
// register, or if not available, then the address is put on the stack. In
// order to pass indirectly, space on the stack often needs to be allocated
// in order to store the value. In this case the CCInfo.getNextStackOffset()
// != 0 check is not enough and we need to check if any CCValAssign ArgsLocs
// are passed CCValAssign::Indirect.
for (auto &VA : ArgLocs)
if (VA.getLocInfo() == CCValAssign::Indirect)
return false;

// Do not tail call opt if either caller or callee uses struct return
// semantics.
auto IsCallerStructRet = Caller.hasStructRetAttr();
auto IsCalleeStructRet = Outs.empty() ? false : Outs[0].Flags.isSRet();
if (IsCallerStructRet || IsCalleeStructRet)
if (IsCallerStructRet != IsCalleeStructRet)
return false;

// Do not tail call opt if caller's and callee's byval arguments do not match.
for (unsigned i = 0, j = 0; i < Outs.size(); i++) {
if (!Outs[i].Flags.isByVal())
continue;
if (j++ >= RVFI->getIncomingByValArgsSize())
return false;
if (RVFI->getIncomingByValArgs(i).getValueType() != Outs[i].ArgVT)
return false;
}

// The callee has to preserve all registers the caller needs to preserve.
const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
const uint32_t *CallerPreserved = TRI->getCallPreservedMask(MF, CallerCC);
Expand All @@ -23670,12 +23674,12 @@ bool RISCVTargetLowering::isEligibleForTailCallOptimization(
return false;
}

// Byval parameters hand the function a pointer directly into the stack area
// we want to reuse during a tail call. Working around this *is* possible
// but less efficient and uglier in LowerCall.
for (auto &Arg : Outs)
if (Arg.Flags.isByVal())
return false;
// If the callee takes no arguments then go on to check the results of the
// call.
const MachineRegisterInfo &MRI = MF.getRegInfo();
const SmallVectorImpl<SDValue> &OutVals = CLI.OutVals;
if (!parametersInCSRMatch(MRI, CallerPreserved, ArgLocs, OutVals))
return false;

return true;
}
Expand Down Expand Up @@ -23704,6 +23708,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
const CallBase *CB = CLI.CB;

MachineFunction &MF = DAG.getMachineFunction();
RISCVMachineFunctionInfo *RVFI = MF.getInfo<RISCVMachineFunctionInfo>();
MachineFunction::CallSiteInfo CSInfo;

// Set type id for call site info.
Expand Down Expand Up @@ -23738,7 +23743,7 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,

// Create local copies for byval args
SmallVector<SDValue, 8> ByValArgs;
for (unsigned i = 0, e = Outs.size(); i != e; ++i) {
for (unsigned i = 0, j = 0, e = Outs.size(); i != e; ++i) {
ISD::ArgFlagsTy Flags = Outs[i].Flags;
if (!Flags.isByVal())
continue;
Expand All @@ -23747,16 +23752,27 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
unsigned Size = Flags.getByValSize();
Align Alignment = Flags.getNonZeroByValAlign();

int FI =
MF.getFrameInfo().CreateStackObject(Size, Alignment, /*isSS=*/false);
SDValue FIPtr = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
SDValue SizeNode = DAG.getConstant(Size, DL, XLenVT);
SDValue Dst;

Chain = DAG.getMemcpy(Chain, DL, FIPtr, Arg, SizeNode, Alignment,
/*IsVolatile=*/false,
/*AlwaysInline=*/false, /*CI*/ nullptr, IsTailCall,
MachinePointerInfo(), MachinePointerInfo());
ByValArgs.push_back(FIPtr);
if (IsTailCall) {
SDValue CallerArg = RVFI->getIncomingByValArgs(j++);
if (isa<GlobalAddressSDNode>(Arg) || isa<ExternalSymbolSDNode>(Arg) ||
isa<FrameIndexSDNode>(Arg))
Dst = CallerArg;
} else {
int FI =
MF.getFrameInfo().CreateStackObject(Size, Alignment, /*isSS=*/false);
Dst = DAG.getFrameIndex(FI, getPointerTy(DAG.getDataLayout()));
}
if (Dst) {
Chain =
DAG.getMemcpy(Chain, DL, Dst, Arg, SizeNode, Alignment,
/*IsVolatile=*/false,
/*AlwaysInline=*/false, /*CI=*/nullptr, std::nullopt,
MachinePointerInfo(), MachinePointerInfo());
ByValArgs.push_back(Dst);
}
}

if (!IsTailCall)
Expand Down Expand Up @@ -23859,8 +23875,12 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
}

// Use local copy if it is a byval arg.
if (Flags.isByVal())
ArgValue = ByValArgs[j++];
if (Flags.isByVal()) {
if (!IsTailCall || (isa<GlobalAddressSDNode>(ArgValue) ||
isa<ExternalSymbolSDNode>(ArgValue) ||
isa<FrameIndexSDNode>(ArgValue)))
ArgValue = ByValArgs[j++];
}

if (VA.isRegLoc()) {
// Queue up the argument copies and emit them at the end.
Expand All @@ -23871,20 +23891,32 @@ SDValue RISCVTargetLowering::LowerCall(CallLoweringInfo &CLI,
CSInfo.ArgRegPairs.emplace_back(VA.getLocReg(), i);
} else {
assert(VA.isMemLoc() && "Argument not register or memory");
assert(!IsTailCall && "Tail call not allowed if stack is used "
"for passing parameters");
SDValue DstAddr;
MachinePointerInfo DstInfo;
int32_t Offset = VA.getLocMemOffset();

// Work out the address of the stack slot.
if (!StackPtr.getNode())
StackPtr = DAG.getCopyFromReg(Chain, DL, RISCV::X2, PtrVT);
SDValue Address =
DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr,
DAG.getIntPtrConstant(VA.getLocMemOffset(), DL));

if (IsTailCall) {
unsigned OpSize = divideCeil(VA.getValVT().getSizeInBits(), 8);
int FI = MF.getFrameInfo().CreateFixedObject(OpSize, Offset, true);
DstAddr = DAG.getFrameIndex(FI, PtrVT);
DstInfo = MachinePointerInfo::getFixedStack(MF, FI);
// Make sure any stack arguments overlapping with where we're storing
// are loaded before this eventual operation. Otherwise they'll be
// clobbered.
Chain = addTokenForArgument(Chain, DAG, MF.getFrameInfo(), FI);
} else {
SDValue PtrOff = DAG.getIntPtrConstant(Offset, DL);
DstAddr = DAG.getNode(ISD::ADD, DL, PtrVT, StackPtr, PtrOff);
DstInfo = MachinePointerInfo::getStack(MF, Offset);
}

// Emit the store.
MemOpChains.push_back(
DAG.getStore(Chain, DL, ArgValue, Address,
MachinePointerInfo::getStack(MF, VA.getLocMemOffset())));
DAG.getStore(Chain, DL, ArgValue, DstAddr, DstInfo));
}
}

Expand Down
15 changes: 15 additions & 0 deletions llvm/lib/Target/RISCV/RISCVMachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
uint64_t RVVPadding = 0;
/// Size of stack frame to save callee saved registers
unsigned CalleeSavedStackSize = 0;

/// ArgumentStackSize - amount of bytes on stack consumed by the arguments
/// being passed on the stack
unsigned ArgumentStackSize = 0;

/// Incoming ByVal arguments
SmallVector<SDValue, 8> IncomingByValArgs;

/// Is there any vector argument or return?
bool IsVectorCall = false;

Expand Down Expand Up @@ -142,6 +150,13 @@ class RISCVMachineFunctionInfo : public MachineFunctionInfo {
unsigned getCalleeSavedStackSize() const { return CalleeSavedStackSize; }
void setCalleeSavedStackSize(unsigned Size) { CalleeSavedStackSize = Size; }

unsigned getArgumentStackSize() const { return ArgumentStackSize; }
void setArgumentStackSize(unsigned size) { ArgumentStackSize = size; }

void addIncomingByValArgs(SDValue Val) { IncomingByValArgs.push_back(Val); }
SDValue &getIncomingByValArgs(int Idx) { return IncomingByValArgs[Idx]; }
unsigned getIncomingByValArgsSize() { return IncomingByValArgs.size(); }

enum class PushPopKind { None = 0, StdExtZcmp, VendorXqccmp };

PushPopKind getPushPopKind(const MachineFunction &MF) const;
Expand Down
Loading