Skip to content

Commit 20d8398

Browse files
rovkaarsenmshiltian
authored
[AMDGPU] ISel & PEI for whole wave functions (#145858)
Whole wave functions are functions that will run with a full EXEC mask. They will not be invoked directly, but instead will be launched by way of a new intrinsic, `llvm.amdgcn.call.whole.wave` (to be added in a future patch). These functions are meant as an alternative to the `llvm.amdgcn.init.whole.wave` or `llvm.amdgcn.strict.wwm` intrinsics. Whole wave functions will set EXEC to -1 in the prologue and restore the original value of EXEC in the epilogue. They must have a special first argument, `i1 %active`, that is going to be mapped to EXEC. They may have either the default calling convention or amdgpu_gfx. The inactive lanes need to be preserved for all registers used, active lanes only for the CSRs. At the IR level, arguments to a whole wave function (other than `%active`) contain poison in their inactive lanes. Likewise, the return value for the inactive lanes is poison. This patch contains the following work: * 2 new pseudos, SI_SETUP_WHOLE_WAVE_FUNC and SI_WHOLE_WAVE_FUNC_RETURN used for managing the EXEC mask. SI_SETUP_WHOLE_WAVE_FUNC will return a SReg_1 representing `%active`, which needs to be passed into SI_WHOLE_WAVE_FUNC_RETURN. * SelectionDAG support for generating these 2 new pseudos and the special handling of %active. Since the return may be in a different basic block, it's difficult to add the virtual reg for %active to SI_WHOLE_WAVE_FUNC_RETURN, so we initially generate an IMPLICIT_DEF which is later replaced via a custom inserter. * Expansion of the 2 pseudos during prolog/epilog insertion. PEI also marks any used VGPRs as WWM registers, which are then spilled and restored with the usual logic. Future patches will include the `llvm.amdgcn.call.whole.wave` intrinsic and a lot of optimization work (especially in order to reduce spills around function calls). --------- Co-authored-by: Matt Arsenault <[email protected]> Co-authored-by: Shilei Tian <[email protected]>
1 parent e87d390 commit 20d8398

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+3539
-23
lines changed

llvm/docs/AMDGPUUsage.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,6 +1844,20 @@ The AMDGPU backend supports the following calling conventions:
18441844
..TODO::
18451845
Describe.
18461846

1847+
``amdgpu_gfx_whole_wave`` Used for AMD graphics targets. Functions with this calling convention
1848+
cannot be used as entry points. They must have an i1 as the first argument,
1849+
which will be mapped to the value of EXEC on entry into the function. Other
1850+
arguments will contain poison in their inactive lanes. Similarly, the return
1851+
value for the inactive lanes is poison.
1852+
1853+
The function will run with all lanes enabled, i.e. EXEC will be set to -1 in the
1854+
prologue and restored to its original value in the epilogue. The inactive lanes
1855+
will be preserved for all the registers used by the function. Active lanes only
1856+
will only be preserved for the callee saved registers.
1857+
1858+
In all other respects, functions with this calling convention behave like
1859+
``amdgpu_gfx`` functions.
1860+
18471861
``amdgpu_gs`` Used for Mesa/AMDPAL geometry shaders.
18481862
..TODO::
18491863
Describe.

llvm/include/llvm/AsmParser/LLToken.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ enum Kind {
181181
kw_amdgpu_cs_chain_preserve,
182182
kw_amdgpu_kernel,
183183
kw_amdgpu_gfx,
184+
kw_amdgpu_gfx_whole_wave,
184185
kw_tailcc,
185186
kw_m68k_rtdcc,
186187
kw_graalcc,

llvm/include/llvm/IR/CallingConv.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ namespace CallingConv {
284284
RISCV_VLSCall_32768 = 122,
285285
RISCV_VLSCall_65536 = 123,
286286

287+
// Calling convention for AMDGPU whole wave functions.
288+
AMDGPU_Gfx_WholeWave = 124,
289+
287290
/// The highest possible ID. Must be some 2^k - 1.
288291
MaxID = 1023
289292
};
@@ -294,8 +297,13 @@ namespace CallingConv {
294297
/// directly or indirectly via a call-like instruction.
295298
constexpr bool isCallableCC(CallingConv::ID CC) {
296299
switch (CC) {
300+
// Called with special intrinsics:
301+
// llvm.amdgcn.cs.chain
297302
case CallingConv::AMDGPU_CS_Chain:
298303
case CallingConv::AMDGPU_CS_ChainPreserve:
304+
// llvm.amdgcn.call.whole.wave
305+
case CallingConv::AMDGPU_Gfx_WholeWave:
306+
// Hardware entry points:
299307
case CallingConv::AMDGPU_CS:
300308
case CallingConv::AMDGPU_ES:
301309
case CallingConv::AMDGPU_GS:

llvm/lib/AsmParser/LLLexer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,7 @@ lltok::Kind LLLexer::LexIdentifier() {
679679
KEYWORD(amdgpu_cs_chain_preserve);
680680
KEYWORD(amdgpu_kernel);
681681
KEYWORD(amdgpu_gfx);
682+
KEYWORD(amdgpu_gfx_whole_wave);
682683
KEYWORD(tailcc);
683684
KEYWORD(m68k_rtdcc);
684685
KEYWORD(graalcc);

llvm/lib/AsmParser/LLParser.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2272,6 +2272,9 @@ bool LLParser::parseOptionalCallingConv(unsigned &CC) {
22722272
CC = CallingConv::AMDGPU_CS_ChainPreserve;
22732273
break;
22742274
case lltok::kw_amdgpu_kernel: CC = CallingConv::AMDGPU_KERNEL; break;
2275+
case lltok::kw_amdgpu_gfx_whole_wave:
2276+
CC = CallingConv::AMDGPU_Gfx_WholeWave;
2277+
break;
22752278
case lltok::kw_tailcc: CC = CallingConv::Tail; break;
22762279
case lltok::kw_m68k_rtdcc: CC = CallingConv::M68k_RTD; break;
22772280
case lltok::kw_graalcc: CC = CallingConv::GRAAL; break;

llvm/lib/IR/AsmWriter.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,9 @@ static void PrintCallingConv(unsigned cc, raw_ostream &Out) {
404404
break;
405405
case CallingConv::AMDGPU_KERNEL: Out << "amdgpu_kernel"; break;
406406
case CallingConv::AMDGPU_Gfx: Out << "amdgpu_gfx"; break;
407+
case CallingConv::AMDGPU_Gfx_WholeWave:
408+
Out << "amdgpu_gfx_whole_wave";
409+
break;
407410
case CallingConv::M68k_RTD: Out << "m68k_rtdcc"; break;
408411
case CallingConv::RISCV_VectorCall:
409412
Out << "riscv_vector_cc";

llvm/lib/IR/Function.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,7 @@ bool llvm::CallingConv::supportsNonVoidReturnType(CallingConv::ID CC) {
12321232
case CallingConv::AArch64_SVE_VectorCall:
12331233
case CallingConv::WASM_EmscriptenInvoke:
12341234
case CallingConv::AMDGPU_Gfx:
1235+
case CallingConv::AMDGPU_Gfx_WholeWave:
12351236
case CallingConv::M68k_INTR:
12361237
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X0:
12371238
case CallingConv::AArch64_SME_ABI_Support_Routines_PreserveMost_From_X2:

llvm/lib/IR/Verifier.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2979,6 +2979,16 @@ void Verifier::visitFunction(const Function &F) {
29792979
"perfect forwarding!",
29802980
&F);
29812981
break;
2982+
case CallingConv::AMDGPU_Gfx_WholeWave:
2983+
Check(!F.arg_empty() && F.arg_begin()->getType()->isIntegerTy(1),
2984+
"Calling convention requires first argument to be i1", &F);
2985+
Check(!F.arg_begin()->hasInRegAttr(),
2986+
"Calling convention requires first argument to not be inreg", &F);
2987+
Check(!F.isVarArg(),
2988+
"Calling convention does not support varargs or "
2989+
"perfect forwarding!",
2990+
&F);
2991+
break;
29822992
}
29832993

29842994
// Check that the argument values match the function type for this function...

llvm/lib/Target/AMDGPU/AMDGPUCallLowering.cpp

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -374,15 +374,20 @@ bool AMDGPUCallLowering::lowerReturn(MachineIRBuilder &B, const Value *Val,
374374
return true;
375375
}
376376

377-
unsigned ReturnOpc =
378-
IsShader ? AMDGPU::SI_RETURN_TO_EPILOG : AMDGPU::SI_RETURN;
377+
const bool IsWholeWave = MFI->isWholeWaveFunction();
378+
unsigned ReturnOpc = IsWholeWave ? AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_RETURN
379+
: IsShader ? AMDGPU::SI_RETURN_TO_EPILOG
380+
: AMDGPU::SI_RETURN;
379381
auto Ret = B.buildInstrNoInsert(ReturnOpc);
380382

381383
if (!FLI.CanLowerReturn)
382384
insertSRetStores(B, Val->getType(), VRegs, FLI.DemoteRegister);
383385
else if (!lowerReturnVal(B, Val, VRegs, Ret))
384386
return false;
385387

388+
if (IsWholeWave)
389+
addOriginalExecToReturn(B.getMF(), Ret);
390+
386391
// TODO: Handle CalleeSavedRegsViaCopy.
387392

388393
B.insertInstr(Ret);
@@ -632,6 +637,17 @@ bool AMDGPUCallLowering::lowerFormalArguments(
632637
if (DL.getTypeStoreSize(Arg.getType()) == 0)
633638
continue;
634639

640+
if (Info->isWholeWaveFunction() && Idx == 0) {
641+
assert(VRegs[Idx].size() == 1 && "Expected only one register");
642+
643+
// The first argument for whole wave functions is the original EXEC value.
644+
B.buildInstr(AMDGPU::G_AMDGPU_WHOLE_WAVE_FUNC_SETUP)
645+
.addDef(VRegs[Idx][0]);
646+
647+
++Idx;
648+
continue;
649+
}
650+
635651
const bool InReg = Arg.hasAttribute(Attribute::InReg);
636652

637653
if (Arg.hasAttribute(Attribute::SwiftSelf) ||
@@ -1347,6 +1363,7 @@ bool AMDGPUCallLowering::lowerTailCall(
13471363
SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs;
13481364

13491365
if (Info.CallConv != CallingConv::AMDGPU_Gfx &&
1366+
Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave &&
13501367
!AMDGPU::isChainCC(Info.CallConv)) {
13511368
// With a fixed ABI, allocate fixed registers before user arguments.
13521369
if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info))
@@ -1524,7 +1541,8 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
15241541
// after the ordinary user argument registers.
15251542
SmallVector<std::pair<MCRegister, Register>, 12> ImplicitArgRegs;
15261543

1527-
if (Info.CallConv != CallingConv::AMDGPU_Gfx) {
1544+
if (Info.CallConv != CallingConv::AMDGPU_Gfx &&
1545+
Info.CallConv != CallingConv::AMDGPU_Gfx_WholeWave) {
15281546
// With a fixed ABI, allocate fixed registers before user arguments.
15291547
if (!passSpecialInputs(MIRBuilder, CCInfo, ImplicitArgRegs, Info))
15301548
return false;
@@ -1592,3 +1610,11 @@ bool AMDGPUCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
15921610

15931611
return true;
15941612
}
1613+
1614+
void AMDGPUCallLowering::addOriginalExecToReturn(
1615+
MachineFunction &MF, MachineInstrBuilder &Ret) const {
1616+
const GCNSubtarget &ST = MF.getSubtarget<GCNSubtarget>();
1617+
const SIInstrInfo *TII = ST.getInstrInfo();
1618+
const MachineInstr *Setup = TII->getWholeWaveFunctionSetup(MF);
1619+
Ret.addReg(Setup->getOperand(0).getReg());
1620+
}

llvm/lib/Target/AMDGPU/AMDGPUCallLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class AMDGPUCallLowering final : public CallLowering {
3737
bool lowerReturnVal(MachineIRBuilder &B, const Value *Val,
3838
ArrayRef<Register> VRegs, MachineInstrBuilder &Ret) const;
3939

40+
void addOriginalExecToReturn(MachineFunction &MF,
41+
MachineInstrBuilder &Ret) const;
42+
4043
public:
4144
AMDGPUCallLowering(const AMDGPUTargetLowering &TLI);
4245

0 commit comments

Comments
 (0)