Skip to content

Commit 790c596

Browse files
committed
[AMDGPU] Tail call support for whole wave functions
Support tail calls to whole wave functions (trivial) and from whole wave functions (slightly more involved, because we need a new pseudo for the tail-call return that patches up the EXEC mask). Move the expansion of whole wave function return pseudos (regular and tail-call returns) to prolog/epilog insertion, since that's where we patch up the EXEC mask. Unnecessary register spills will be dealt with in a future patch.
1 parent 8dc9461 commit 790c596

14 files changed

+2470
-41
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7999,13 +7999,18 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
79997999
}
80008000
case Intrinsic::amdgcn_call_whole_wave: {
80018001
TargetLowering::ArgListTy Args;
8002+
bool isTailCall = I.isTailCall();
80028003

80038004
// The first argument is the callee. Skip it when assembling the call args.
80048005
TargetLowering::ArgListEntry Arg;
80058006
for (unsigned Idx = 1; Idx < I.arg_size(); ++Idx) {
80068007
Arg.Node = getValue(I.getArgOperand(Idx));
80078008
Arg.Ty = I.getArgOperand(Idx)->getType();
80088009
Arg.setAttributes(&I, Idx);
8010+
8011+
if (Arg.IsSRet && isa<Instruction>(I.getArgOperand(Idx)))
8012+
isTailCall = false;
8013+
80098014
Args.push_back(Arg);
80108015
}
80118016

@@ -8020,7 +8025,7 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
80208025
.setChain(getRoot())
80218026
.setCallee(CallingConv::AMDGPU_Gfx_WholeWave, I.getType(),
80228027
getValue(I.getArgOperand(0)), std::move(Args))
8023-
.setTailCall(false)
8028+
.setTailCall(isTailCall && canTailCall(I))
80248029
.setIsPreallocated(
80258030
I.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
80268031
.setConvergent(I.isConvergent())
@@ -8901,6 +8906,29 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
89018906
return Result;
89028907
}
89038908

8909+
bool SelectionDAGBuilder::canTailCall(const CallBase &CB) const {
8910+
bool isMustTailCall = CB.isMustTailCall();
8911+
8912+
// Avoid emitting tail calls in functions with the disable-tail-calls
8913+
// attribute.
8914+
auto *Caller = CB.getParent()->getParent();
8915+
if (Caller->getFnAttribute("disable-tail-calls").getValueAsString() ==
8916+
"true" &&
8917+
!isMustTailCall)
8918+
return false;
8919+
8920+
// We can't tail call inside a function with a swifterror argument. Lowering
8921+
// does not support this yet. It would have to move into the swifterror
8922+
// register before the call.
8923+
if (DAG.getTargetLoweringInfo().supportSwiftError() &&
8924+
Caller->getAttributes().hasAttrSomewhere(Attribute::SwiftError))
8925+
return false;
8926+
8927+
// Check if target-independent constraints permit a tail call here.
8928+
// Target-dependent constraints are checked within TLI->LowerCallTo.
8929+
return isInTailCallPosition(CB, DAG.getTarget());
8930+
}
8931+
89048932
void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
89058933
bool isTailCall, bool isMustTailCall,
89068934
const BasicBlock *EHPadBB,
@@ -8915,21 +8943,8 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
89158943
const Value *SwiftErrorVal = nullptr;
89168944
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
89178945

8918-
if (isTailCall) {
8919-
// Avoid emitting tail calls in functions with the disable-tail-calls
8920-
// attribute.
8921-
auto *Caller = CB.getParent()->getParent();
8922-
if (Caller->getFnAttribute("disable-tail-calls").getValueAsString() ==
8923-
"true" && !isMustTailCall)
8924-
isTailCall = false;
8925-
8926-
// We can't tail call inside a function with a swifterror argument. Lowering
8927-
// does not support this yet. It would have to move into the swifterror
8928-
// register before the call.
8929-
if (TLI.supportSwiftError() &&
8930-
Caller->getAttributes().hasAttrSomewhere(Attribute::SwiftError))
8931-
isTailCall = false;
8932-
}
8946+
if (isTailCall)
8947+
isTailCall = canTailCall(CB);
89338948

89348949
for (auto I = CB.arg_begin(), E = CB.arg_end(); I != E; ++I) {
89358950
TargetLowering::ArgListEntry Entry;
@@ -8974,11 +8989,6 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
89748989
Args.push_back(Entry);
89758990
}
89768991

8977-
// Check if target-independent constraints permit a tail call here.
8978-
// Target-dependent constraints are checked within TLI->LowerCallTo.
8979-
if (isTailCall && !isInTailCallPosition(CB, DAG.getTarget()))
8980-
isTailCall = false;
8981-
89828992
// Disable tail calls if there is an swifterror argument. Targets have not
89838993
// been updated to support tail calls.
89848994
if (TLI.supportSwiftError() && SwiftErrorVal)

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,10 @@ class SelectionDAGBuilder {
408408
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr,
409409
const TargetLowering::PtrAuthInfo *PAI = nullptr);
410410

411+
// Check some of the target-independent constraints for tail calls. This does
412+
// not iterate over the call arguments.
413+
bool canTailCall(const CallBase &CB) const;
414+
411415
// Lower range metadata from 0 to N to assert zext to an integer of nearest
412416
// floor power of two.
413417
SDValue lowerRangeToAssertZExt(SelectionDAG &DAG, const Instruction &I,

llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -993,8 +993,14 @@ static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
993993
return IsWave32 ? AMDGPU::SI_CS_CHAIN_TC_W32 : AMDGPU::SI_CS_CHAIN_TC_W64;
994994
}
995995

996-
return CC == CallingConv::AMDGPU_Gfx ? AMDGPU::SI_TCRETURN_GFX :
997-
AMDGPU::SI_TCRETURN;
996+
if (CallerF.getFunction().getCallingConv() ==
997+
CallingConv::AMDGPU_Gfx_WholeWave)
998+
return AMDGPU::SI_TCRETURN_GFX_WholeWave;
999+
1000+
if (CC == CallingConv::AMDGPU_Gfx || CC == CallingConv::AMDGPU_Gfx_WholeWave)
1001+
return AMDGPU::SI_TCRETURN_GFX;
1002+
1003+
return AMDGPU::SI_TCRETURN;
9981004
}
9991005

10001006
// Add operands to call instruction to track the callee.
@@ -1273,6 +1279,13 @@ bool AMDGPUCallLowering::lowerTailCall(
12731279
unsigned Opc = getCallOpcode(MF, Info.Callee.isReg(), /*IsTailCall*/ true,
12741280
ST.isWave32(), CalleeCC, IsDynamicVGPRChainCall);
12751281
auto MIB = MIRBuilder.buildInstrNoInsert(Opc);
1282+
1283+
if (FuncInfo->isWholeWaveFunction())
1284+
addOriginalExecToReturn(MF, MIB);
1285+
1286+
// Keep track of the index of the next operand to be added to the call
1287+
unsigned CalleeIdx = MIB->getNumOperands();
1288+
12761289
if (!addCallTargetOperands(MIB, MIRBuilder, Info, IsDynamicVGPRChainCall))
12771290
return false;
12781291

@@ -1390,7 +1403,7 @@ bool AMDGPUCallLowering::lowerTailCall(
13901403
// If we have -tailcallopt, we need to adjust the stack. We'll do the call
13911404
// sequence start and end here.
13921405
if (!IsSibCall) {
1393-
MIB->getOperand(1).setImm(FPDiff);
1406+
MIB->getOperand(CalleeIdx + 1).setImm(FPDiff);
13941407
CallSeqStart.addImm(NumBytes).addImm(0);
13951408
// End the call sequence *before* emitting the call. Normally, we would
13961409
// tidy the frame up after the call. However, here, we've laid out the
@@ -1402,16 +1415,24 @@ bool AMDGPUCallLowering::lowerTailCall(
14021415
// Now we can add the actual call instruction to the correct basic block.
14031416
MIRBuilder.insertInstr(MIB);
14041417

1418+
// If this is a whole wave tail call, we need to constrain the register for
1419+
// the original EXEC.
1420+
if (MIB->getOpcode() == AMDGPU::SI_TCRETURN_GFX_WholeWave) {
1421+
MIB->getOperand(0).setReg(
1422+
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
1423+
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
1424+
}
1425+
14051426
// If Callee is a reg, since it is used by a target specific
14061427
// instruction, it must have a register class matching the
14071428
// constraint of that instruction.
14081429

14091430
// FIXME: We should define regbankselectable call instructions to handle
14101431
// divergent call targets.
1411-
if (MIB->getOperand(0).isReg()) {
1412-
MIB->getOperand(0).setReg(
1413-
constrainOperandRegClass(MF, *TRI, MRI, *TII, *ST.getRegBankInfo(),
1414-
*MIB, MIB->getDesc(), MIB->getOperand(0), 0));
1432+
if (MIB->getOperand(CalleeIdx).isReg()) {
1433+
MIB->getOperand(CalleeIdx).setReg(constrainOperandRegClass(
1434+
MF, *TRI, MRI, *TII, *ST.getRegBankInfo(), *MIB, MIB->getDesc(),
1435+
MIB->getOperand(CalleeIdx), CalleeIdx));
14151436
}
14161437

14171438
MF.getFrameInfo().setHasTailCall();

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5736,6 +5736,7 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
57365736
NODE_NAME_CASE(CALL)
57375737
NODE_NAME_CASE(TC_RETURN)
57385738
NODE_NAME_CASE(TC_RETURN_GFX)
5739+
NODE_NAME_CASE(TC_RETURN_GFX_WholeWave)
57395740
NODE_NAME_CASE(TC_RETURN_CHAIN)
57405741
NODE_NAME_CASE(TC_RETURN_CHAIN_DVGPR)
57415742
NODE_NAME_CASE(TRAP)

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ enum NodeType : unsigned {
402402
CALL,
403403
TC_RETURN,
404404
TC_RETURN_GFX,
405+
TC_RETURN_GFX_WholeWave,
405406
TC_RETURN_CHAIN,
406407
TC_RETURN_CHAIN_DVGPR,
407408
TRAP,

llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ def AMDGPUtc_return_gfx: SDNode<"AMDGPUISD::TC_RETURN_GFX", AMDGPUTCReturnTP,
9494
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
9595
>;
9696

97+
def AMDGPUtc_return_gfx_ww: SDNode<"AMDGPUISD::TC_RETURN_GFX_WholeWave", AMDGPUTCReturnTP,
98+
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
99+
>;
100+
97101
def AMDGPUtc_return_chain: SDNode<"AMDGPUISD::TC_RETURN_CHAIN",
98102
SDTypeProfile<0, -1, [SDTCisPtrTy<0>]>,
99103
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]

llvm/lib/Target/AMDGPU/SIFrameLowering.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,9 +1125,18 @@ void SIFrameLowering::emitCSRSpillRestores(
11251125
RestoreWWMRegisters(WWMCalleeSavedRegs);
11261126

11271127
// The original EXEC is the first operand of the return instruction.
1128-
const MachineInstr &Return = MBB.instr_back();
1129-
assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN &&
1130-
"Unexpected return inst");
1128+
MachineInstr &Return = MBB.instr_back();
1129+
unsigned Opcode = Return.getOpcode();
1130+
switch (Opcode) {
1131+
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN:
1132+
Opcode = AMDGPU::SI_RETURN;
1133+
break;
1134+
case AMDGPU::SI_TCRETURN_GFX_WholeWave:
1135+
Opcode = AMDGPU::SI_TCRETURN_GFX;
1136+
break;
1137+
default:
1138+
llvm_unreachable("Unexpected return inst");
1139+
}
11311140
Register OrigExec = Return.getOperand(0).getReg();
11321141

11331142
if (!WWMScratchRegs.empty()) {
@@ -1141,6 +1150,11 @@ void SIFrameLowering::emitCSRSpillRestores(
11411150
// Restore original EXEC.
11421151
unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
11431152
BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec);
1153+
1154+
// Drop the first operand and update the opcode.
1155+
Return.removeOperand(0);
1156+
Return.setDesc(TII->get(Opcode));
1157+
11441158
return;
11451159
}
11461160

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4163,6 +4163,11 @@ SDValue SITargetLowering::LowerCall(CallLoweringInfo &CLI,
41634163
break;
41644164
}
41654165

4166+
// If the caller is a whole wave function, we need to use a special opcode
4167+
// so we can patch up EXEC.
4168+
if (Info->isWholeWaveFunction())
4169+
OPC = AMDGPUISD::TC_RETURN_GFX_WholeWave;
4170+
41664171
return DAG.getNode(OPC, DL, MVT::Other, Ops);
41674172
}
41684173

@@ -5904,6 +5909,7 @@ SITargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
59045909
MI.eraseFromParent();
59055910
return SplitBB;
59065911
}
5912+
case AMDGPU::SI_TCRETURN_GFX_WholeWave:
59075913
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: {
59085914
assert(MFI->isWholeWaveFunction());
59095915

llvm/lib/Target/AMDGPU/SIInstrInfo.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2543,7 +2543,6 @@ bool SIInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
25432543
MI.setDesc(get(ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64));
25442544
break;
25452545
}
2546-
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN:
25472546
case AMDGPU::SI_RETURN: {
25482547
const MachineFunction *MF = MBB.getParent();
25492548
const GCNSubtarget &ST = MF->getSubtarget<GCNSubtarget>();

llvm/lib/Target/AMDGPU/SIInstructions.td

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,33 @@ def SI_WHOLE_WAVE_FUNC_RETURN : SPseudoInstSI <
670670
def : GCNPat<
671671
(AMDGPUwhole_wave_return), (SI_WHOLE_WAVE_FUNC_RETURN (i1 (IMPLICIT_DEF)))>;
672672

673+
// Restores the previous EXEC and otherwise behaves entirely like a SI_TCRETURN.
674+
// This is used for tail calls *from* a whole wave function. Tail calls to
675+
// a whole wave function may use the usual opcodes, depending on the calling
676+
// convention of the caller.
677+
def SI_TCRETURN_GFX_WholeWave : SPseudoInstSI <
678+
(outs),
679+
(ins SReg_1:$orig_exec, Gfx_CCR_SGPR_64:$src0, unknown:$callee, i32imm:$fpdiff)> {
680+
let isCall = 1;
681+
let isTerminator = 1;
682+
let isReturn = 1;
683+
let isBarrier = 1;
684+
let UseNamedOperandTable = 1;
685+
let SchedRW = [WriteBranch];
686+
let isConvergent = 1;
687+
688+
// We're going to use custom handling to set the $orig_exec to the correct value.
689+
let usesCustomInserter = 1;
690+
}
691+
692+
// Generate a SI_TCRETURN_GFX_WholeWave pseudo with a placeholder for its
693+
// argument. It will be filled in by the custom inserter.
694+
def : GCNPat<
695+
(AMDGPUtc_return_gfx_ww i64:$src0, tglobaladdr:$callee, i32:$fpdiff),
696+
(SI_TCRETURN_GFX_WholeWave (i1 (IMPLICIT_DEF)), Gfx_CCR_SGPR_64:$src0,
697+
tglobaladdr:$callee, i32:$fpdiff)>;
698+
699+
673700
// Return for returning shaders to a shader variant epilog.
674701
def SI_RETURN_TO_EPILOG : SPseudoInstSI <
675702
(outs), (ins variable_ops), [(AMDGPUreturn_to_epilog)]> {

0 commit comments

Comments
 (0)