Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a525bba
Add subtarget feature
rovka Jan 24, 2025
1ceab6a
[AMDGPU] ISel & PEI for whole wave functions
rovka Jan 27, 2025
399e08c
Use MF instead of MBB
rovka Mar 17, 2025
8f72b59
Revert "Add subtarget feature"
rovka Mar 11, 2025
accbe8e
Add new CC. Do nothing
rovka Mar 19, 2025
1a82d88
Replace SubtargetFeature with CallingConv
rovka Mar 11, 2025
ea3821b
Enable gisel in tests
rovka Mar 17, 2025
1b20edd
GISel support
rovka Mar 11, 2025
5e97750
Rename pseudo to match others
rovka Mar 19, 2025
be094ce
Rename CC
rovka Mar 25, 2025
b1a17c6
Fix formatting
rovka Mar 25, 2025
75017e9
Merge branch 'main' into whole-wave-funcs
rovka Apr 7, 2025
4c6beec
Merge remote-tracking branch 'remotes/origin/main' into whole-wave-funcs
rovka Apr 30, 2025
80e6433
Update tests after merge
rovka May 6, 2025
552e220
Fix bug in testcase
rovka May 6, 2025
7ed7e96
Test inreg args
rovka May 19, 2025
8325ef1
Merge remote-tracking branch 'remotes/origin/main' into whole-wave-funcs
rovka May 20, 2025
e1f133e
Add docs and fixme
rovka May 20, 2025
ac70a87
Remove kill flags on orig exec mask
rovka Jun 17, 2025
08102a3
Add helper to add orig exec to return
rovka Jun 23, 2025
1cd402f
Test with single use of orig exec
rovka Jun 23, 2025
e8fc4bd
Test calling gfx func from wwf
rovka Jun 23, 2025
8feed10
Test wave64
rovka Jun 24, 2025
bc7b9ef
Merge remote-tracking branch 'remotes/origin/main' into whole-wave-funcs
rovka Jun 24, 2025
ba08290
Merge remote-tracking branch 'remotes/origin/main' into whole-wave-funcs
rovka Jun 24, 2025
bc8d8ce
Fix a few missed spots
rovka Jun 24, 2025
0eb6c66
clang-format
rovka Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPU.td
Original file line number Diff line number Diff line change
Expand Up @@ -1251,6 +1251,12 @@ def FeatureXF32Insts : SubtargetFeature<"xf32-insts",
"v_mfma_f32_16x16x8_xf32 and v_mfma_f32_32x32x4_xf32"
>;

// Marks the current function as a "whole wave function", i.e. one that runs
// with all lanes of the wave enabled. Sets GCNSubtarget::IsWholeWaveFunction,
// which gates the whole-wave setup/return pseudo instructions.
def FeatureWholeWaveFunction : SubtargetFeature<"whole-wave-function",
  "IsWholeWaveFunction",
  "true",
  "Current function is a whole wave function (runs with all lanes enabled)"
>;

// Dummy feature used to disable assembler instructions.
def FeatureDisable : SubtargetFeature<"",
"FeatureDisable","true",
Expand Down Expand Up @@ -2532,6 +2538,8 @@ def HasXF32Insts : Predicate<"Subtarget->hasXF32Insts()">,
def HasAshrPkInsts : Predicate<"Subtarget->hasAshrPkInsts()">,
AssemblerPredicate<(all_of FeatureAshrPkInsts)>;

def IsWholeWaveFunction : Predicate<"Subtarget->isWholeWaveFunction()">;

// Include AMDGPU TD files
include "SISchedule.td"
include "GCNProcessors.td"
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5630,6 +5630,8 @@ const char* AMDGPUTargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(BUFFER_ATOMIC_FMIN)
NODE_NAME_CASE(BUFFER_ATOMIC_FMAX)
NODE_NAME_CASE(BUFFER_ATOMIC_COND_SUB_U32)
NODE_NAME_CASE(WHOLE_WAVE_SETUP)
NODE_NAME_CASE(WHOLE_WAVE_RETURN)
}
return nullptr;
}
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,12 @@ enum NodeType : unsigned {
BUFFER_ATOMIC_FMAX,
BUFFER_ATOMIC_COND_SUB_U32,
LAST_MEMORY_OPCODE = BUFFER_ATOMIC_COND_SUB_U32,

// Set up a whole wave function.
WHOLE_WAVE_SETUP,

// Return from a whole wave function.
WHOLE_WAVE_RETURN,
};

} // End namespace AMDGPUISD
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPUInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,17 @@ def AMDGPUfdot2_impl : SDNode<"AMDGPUISD::FDOT2",

def AMDGPUperm_impl : SDNode<"AMDGPUISD::PERM", AMDGPUDTIntTernaryOp, []>;

// Marks the entry into a whole wave function. Produces a single integer
// result holding the original EXEC mask (the lanes that were active on
// entry); has a chain and a side effect since the matching pseudo turns on
// all lanes (see SI_SETUP_WHOLE_WAVE_FUNC in SIInstructions.td).
def AMDGPUwhole_wave_setup : SDNode<
  "AMDGPUISD::WHOLE_WAVE_SETUP", SDTypeProfile<1, 0, [SDTCisInt<0>]>,
  [SDNPHasChain, SDNPSideEffect]>;

// Marks the return from a whole wave function. No fixed operands (SDTNone);
// variadic so the return operands can be attached, and may consume incoming
// glue like other return nodes.
def AMDGPUwhole_wave_return : SDNode<
  "AMDGPUISD::WHOLE_WAVE_RETURN", SDTNone,
  [SDNPHasChain, SDNPOptInGlue, SDNPVariadic]
>;

// SI+ export
def AMDGPUExportOp : SDTypeProfile<0, 8, [
SDTCisInt<0>, // i8 tgt
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AMDGPU/GCNSubtarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,

bool RequiresCOV6 = false;

bool IsWholeWaveFunction = false;

// Dummy feature to use for assembler in tablegen.
bool FeatureDisable = false;

Expand Down Expand Up @@ -1448,6 +1450,10 @@ class GCNSubtarget final : public AMDGPUGenSubtargetInfo,
// of sign-extending.
bool hasGetPCZeroExtension() const { return GFX12Insts; }

/// \returns true if the current function is a whole wave function (i.e. it
/// runs with all the lanes enabled).
bool isWholeWaveFunction() const { return IsWholeWaveFunction; }

/// \returns SGPR allocation granularity supported by the subtarget.
unsigned getSGPRAllocGranule() const {
return AMDGPU::IsaInfo::getSGPRAllocGranule(this);
Expand Down
81 changes: 71 additions & 10 deletions llvm/lib/Target/AMDGPU/SIFrameLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -900,8 +900,18 @@ static Register buildScratchExecCopy(LiveRegUnits &LiveUnits,

initLiveUnits(LiveUnits, TRI, FuncInfo, MF, MBB, MBBI, IsProlog);

ScratchExecCopy = findScratchNonCalleeSaveRegister(
MRI, LiveUnits, *TRI.getWaveMaskRegClass());
if (ST.isWholeWaveFunction()) {
// Whole wave functions already have a copy of the original EXEC mask that
// we can use.
assert(IsProlog && "Epilog should look at return, not setup");
ScratchExecCopy =
TII->getWholeWaveFunctionSetup(MBB)->getOperand(0).getReg();
assert(ScratchExecCopy && "Couldn't find copy of EXEC");
} else {
ScratchExecCopy = findScratchNonCalleeSaveRegister(
MRI, LiveUnits, *TRI.getWaveMaskRegClass());
}

if (!ScratchExecCopy)
report_fatal_error("failed to find free scratch register");

Expand Down Expand Up @@ -950,10 +960,15 @@ void SIFrameLowering::emitCSRSpillStores(
};

StoreWWMRegisters(WWMScratchRegs);

auto EnableAllLanes = [&]() {
unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1);
};

if (!WWMCalleeSavedRegs.empty()) {
if (ScratchExecCopy) {
unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addImm(-1);
EnableAllLanes();
} else {
ScratchExecCopy = buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
/*IsProlog*/ true,
Expand All @@ -962,7 +977,15 @@ void SIFrameLowering::emitCSRSpillStores(
}

StoreWWMRegisters(WWMCalleeSavedRegs);
if (ScratchExecCopy) {
if (ST.isWholeWaveFunction()) {
// SI_SETUP_WHOLE_WAVE_FUNC has outlived its purpose, so we can remove
// it now. If we have already saved some WWM CSR registers, then the EXEC is
// already -1 and we don't need to do anything else. Otherwise, set EXEC to
// -1 here.
if (WWMCalleeSavedRegs.empty())
EnableAllLanes();
TII->getWholeWaveFunctionSetup(MBB)->eraseFromParent();
} else if (ScratchExecCopy) {
// FIXME: Split block and make terminator.
unsigned ExecMov = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
BuildMI(MBB, MBBI, DL, TII->get(ExecMov), TRI.getExec())
Expand Down Expand Up @@ -1037,11 +1060,6 @@ void SIFrameLowering::emitCSRSpillRestores(
Register ScratchExecCopy;
SmallVector<std::pair<Register, int>, 2> WWMCalleeSavedRegs, WWMScratchRegs;
FuncInfo->splitWWMSpillRegisters(MF, WWMCalleeSavedRegs, WWMScratchRegs);
if (!WWMScratchRegs.empty())
ScratchExecCopy =
buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
/*IsProlog*/ false, /*EnableInactiveLanes*/ true);

auto RestoreWWMRegisters =
[&](SmallVectorImpl<std::pair<Register, int>> &WWMRegs) {
for (const auto &Reg : WWMRegs) {
Expand All @@ -1052,6 +1070,36 @@ void SIFrameLowering::emitCSRSpillRestores(
}
};

if (ST.isWholeWaveFunction()) {
// For whole wave functions, the EXEC is already -1 at this point.
// Therefore, we can restore the CSR WWM registers right away.
RestoreWWMRegisters(WWMCalleeSavedRegs);

// The original EXEC is the first operand of the return instruction.
const MachineInstr &Return = MBB.instr_back();
assert(Return.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN &&
"Unexpected return inst");
Register OrigExec = Return.getOperand(0).getReg();

if (!WWMScratchRegs.empty()) {
unsigned XorOpc = ST.isWave32() ? AMDGPU::S_XOR_B32 : AMDGPU::S_XOR_B64;
BuildMI(MBB, MBBI, DL, TII->get(XorOpc), TRI.getExec())
.addReg(OrigExec)
.addImm(-1);
RestoreWWMRegisters(WWMScratchRegs);
}

// Restore original EXEC.
unsigned MovOpc = ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64;
BuildMI(MBB, MBBI, DL, TII->get(MovOpc), TRI.getExec()).addReg(OrigExec);
return;
}

if (!WWMScratchRegs.empty())
ScratchExecCopy =
buildScratchExecCopy(LiveUnits, MF, MBB, MBBI, DL,
/*IsProlog*/ false, /*EnableInactiveLanes*/ true);

RestoreWWMRegisters(WWMScratchRegs);
if (!WWMCalleeSavedRegs.empty()) {
if (ScratchExecCopy) {
Expand Down Expand Up @@ -1588,6 +1636,7 @@ void SIFrameLowering::determineCalleeSaves(MachineFunction &MF,
NeedExecCopyReservedReg = true;
else if (MI.getOpcode() == AMDGPU::SI_RETURN ||
MI.getOpcode() == AMDGPU::SI_RETURN_TO_EPILOG ||
MI.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN ||
(MFI->isChainFunction() &&
TII->isChainCallOpcode(MI.getOpcode()))) {
// We expect all return to be the same size.
Expand Down Expand Up @@ -1616,6 +1665,18 @@ void SIFrameLowering::determineCalleeSaves(MachineFunction &MF,
if (MFI->isEntryFunction())
return;

if (ST.isWholeWaveFunction()) {
// In practice, all the VGPRs are WWM registers, and we will need to save at
// least their inactive lanes. Add them to WWMReservedRegs.
assert(!NeedExecCopyReservedReg && "Whole wave functions can use the reg mapped for their i1 argument");
for (MCRegister Reg : AMDGPU::VGPR_32RegClass)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to be expensive.
It's probably fine for making this work, but long term I think we'd need to do this differently.

if (MF.getRegInfo().isPhysRegModified(Reg)) {
MFI->reserveWWMRegister(Reg);
MF.begin()->addLiveIn(Reg);
}
MF.begin()->sortUniqueLiveIns();
}

// Remove any VGPRs used in the return value because these do not need to be saved.
// This prevents CSR restore from clobbering return VGPRs.
if (ReturnMI) {
Expand Down
30 changes: 27 additions & 3 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2877,6 +2877,8 @@ SDValue SITargetLowering::LowerFormalArguments(
!Info->hasWorkGroupIDZ());
}

bool IsWholeWaveFunc = getSubtarget()->isWholeWaveFunction();

if (CallConv == CallingConv::AMDGPU_PS) {
processPSInputArgs(Splits, CallConv, Ins, Skipped, FType, Info);

Expand Down Expand Up @@ -2917,7 +2919,8 @@ SDValue SITargetLowering::LowerFormalArguments(
} else if (IsKernel) {
assert(Info->hasWorkGroupIDX() && Info->hasWorkItemIDX());
} else {
Splits.append(Ins.begin(), Ins.end());
Splits.append(IsWholeWaveFunc ? std::next(Ins.begin()) : Ins.begin(),
Ins.end());
}

if (IsKernel)
Expand Down Expand Up @@ -2948,14 +2951,22 @@ SDValue SITargetLowering::LowerFormalArguments(

SmallVector<SDValue, 16> Chains;

if (IsWholeWaveFunc) {
SDValue Setup = DAG.getNode(AMDGPUISD::WHOLE_WAVE_SETUP, DL,
{MVT::i1, MVT::Other}, Chain);
InVals.push_back(Setup.getValue(0));
Chains.push_back(Setup.getValue(1));
}

// FIXME: This is the minimum kernel argument alignment. We should improve
// this to the maximum alignment of the arguments.
//
// FIXME: Alignment of explicit arguments totally broken with non-0 explicit
// kern arg offset.
const Align KernelArgBaseAlign = Align(16);

for (unsigned i = 0, e = Ins.size(), ArgIdx = 0; i != e; ++i) {
for (unsigned i = IsWholeWaveFunc ? 1 : 0, e = Ins.size(), ArgIdx = 0; i != e;
++i) {
const ISD::InputArg &Arg = Ins[i];
if ((Arg.isOrigArg() && Skipped[Arg.getOrigArgIndex()]) || IsError) {
InVals.push_back(DAG.getUNDEF(Arg.VT));
Expand Down Expand Up @@ -3300,7 +3311,9 @@ SITargetLowering::LowerReturn(SDValue Chain, CallingConv::ID CallConv,

unsigned Opc = AMDGPUISD::ENDPGM;
if (!IsWaveEnd)
Opc = IsShader ? AMDGPUISD::RETURN_TO_EPILOG : AMDGPUISD::RET_GLUE;
Opc = Subtarget->isWholeWaveFunction() ? AMDGPUISD::WHOLE_WAVE_RETURN
: IsShader ? AMDGPUISD::RETURN_TO_EPILOG
: AMDGPUISD::RET_GLUE;
return DAG.getNode(Opc, DL, MVT::Other, RetOps);
}

Expand Down Expand Up @@ -5670,6 +5683,17 @@ SITargetLowering::EmitInstrWithCustomInserter(MachineInstr &MI,
MI.eraseFromParent();
return SplitBB;
}
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN: {
assert(Subtarget->isWholeWaveFunction());

// During ISel, it's difficult to propagate the original EXEC mask to use as
// an input to SI_WHOLE_WAVE_FUNC_RETURN. Set it up here instead.
MachineInstr *Setup =
TII->getWholeWaveFunctionSetup(*BB->getParent()->begin());
assert(Setup && "Couldn't find SI_SETUP_WHOLE_WAVE_FUNC");
MI.getOperand(0).setReg(Setup->getOperand(0).getReg());
return BB;
}
default:
if (TII->isImage(MI) || TII->isMUBUF(MI)) {
if (!MI.mayStore())
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1636,6 +1636,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
// with knowledge of the called routines.
if (MI.getOpcode() == AMDGPU::SI_RETURN_TO_EPILOG ||
MI.getOpcode() == AMDGPU::SI_RETURN ||
MI.getOpcode() == AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN ||
MI.getOpcode() == AMDGPU::S_SETPC_B64_return ||
(MI.isReturn() && MI.isCall() && !callWaitsOnFunctionEntry(MI))) {
Wait = Wait.combined(WCG->getAllZeroWaitcnt(/*IncludeVSCnt=*/false));
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2509,6 +2509,7 @@ bool SIInstrInfo::expandPostRAPseudo(MachineInstr &MI) const {
MI.setDesc(get(ST.isWave32() ? AMDGPU::S_MOV_B32 : AMDGPU::S_MOV_B64));
break;
}
case AMDGPU::SI_WHOLE_WAVE_FUNC_RETURN:
case AMDGPU::SI_RETURN: {
const MachineFunction *MF = MBB.getParent();
const GCNSubtarget &ST = MF->getSubtarget<GCNSubtarget>();
Expand Down Expand Up @@ -5773,6 +5774,16 @@ void SIInstrInfo::restoreExec(MachineFunction &MF, MachineBasicBlock &MBB,
Indexes->insertMachineInstrInMaps(*ExecRestoreMI);
}

/// Returns the SI_SETUP_WHOLE_WAVE_FUNC pseudo instruction in \p MBB.
///
/// Whole wave functions materialize a copy of the original EXEC mask as the
/// result of SI_SETUP_WHOLE_WAVE_FUNC during instruction selection; frame
/// lowering retrieves that instruction (and its result register) through this
/// helper. Only valid on whole wave functions (asserted), and the setup
/// instruction must be present in \p MBB — typically the entry block —
/// otherwise this is unreachable.
MachineInstr *
SIInstrInfo::getWholeWaveFunctionSetup(MachineBasicBlock &MBB) const {
  assert(ST.isWholeWaveFunction() && "Not a whole wave func");
  // Linear scan: the setup pseudo survives from ISel until frame lowering
  // erases it, so callers before that point are guaranteed to find it.
  for (MachineInstr &MI : MBB)
    if (MI.getOpcode() == AMDGPU::SI_SETUP_WHOLE_WAVE_FUNC)
      return &MI;

  llvm_unreachable("Couldn't find instruction. Wrong MBB?");
}

static const TargetRegisterClass *
adjustAllocatableRegClass(const GCNSubtarget &ST, const SIRegisterInfo &RI,
const MachineRegisterInfo &MRI,
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/SIInstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,8 @@ class SIInstrInfo final : public AMDGPUGenInstrInfo {
MachineBasicBlock::iterator MBBI, const DebugLoc &DL,
Register Reg, SlotIndexes *Indexes = nullptr) const;

MachineInstr *getWholeWaveFunctionSetup(MachineBasicBlock &MBB) const;

/// Return the correct register class for \p OpNo. For target-specific
/// instructions, this will return the register class that has been defined
/// in tablegen. For generic instructions, like REG_SEQUENCE it will return
Expand Down
29 changes: 29 additions & 0 deletions llvm/lib/Target/AMDGPU/SIInstructions.td
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,35 @@ def SI_INIT_WHOLE_WAVE : SPseudoInstSI <
let isConvergent = 1;
}

// Pseudos implementing the whole wave function ABI: the setup saves the
// incoming EXEC mask and enables all lanes; the return restores it.
let SubtargetPredicate = IsWholeWaveFunction in {
// Sets EXEC to all lanes and returns the previous EXEC mask in $dst.
def SI_SETUP_WHOLE_WAVE_FUNC : SPseudoInstSI <
  (outs SReg_1:$dst), (ins), [(set i1:$dst, (AMDGPUwhole_wave_setup))]> {
  let Defs = [EXEC];
  let Uses = [EXEC];

  // EXEC manipulation must stay convergent; it cannot be made control
  // dependent on additional conditions.
  let isConvergent = 1;
}

// Restores the previous EXEC (passed in $orig_exec) and otherwise behaves
// entirely like a SI_RETURN.
def SI_WHOLE_WAVE_FUNC_RETURN : SPseudoInstSI <
  (outs), (ins SReg_1:$orig_exec)> {
  let isTerminator = 1;
  let isBarrier = 1;
  let isReturn = 1;
  let SchedRW = [WriteBranch];

  // We're going to use custom handling to set the $orig_exec to the correct
  // value (see EmitInstrWithCustomInserter in SIISelLowering.cpp).
  let usesCustomInserter = 1;
}

// Generate a SI_WHOLE_WAVE_FUNC_RETURN pseudo with a placeholder for its
// argument. It will be filled in by the custom inserter.
def : GCNPat<
  (AMDGPUwhole_wave_return), (SI_WHOLE_WAVE_FUNC_RETURN (i1 (IMPLICIT_DEF)))>;

} // SubtargetPredicate = IsWholeWaveFunction

// Return for returning shaders to a shader variant epilog.
def SI_RETURN_TO_EPILOG : SPseudoInstSI <
(outs), (ins variable_ops), [(AMDGPUreturn_to_epilog)]> {
Expand Down
Loading
Loading