Skip to content

Commit 4cf04ad

Browse files
authored
[AMDGPU] Hazard handling for gfx1250 wmma instructions (llvm#149865) (llvm#3973)
2 parents 7a43484 + f2922f0 commit 4cf04ad

File tree

4 files changed

+2513
-3
lines changed

4 files changed

+2513
-3
lines changed

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.cpp

Lines changed: 180 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -525,8 +525,8 @@ static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard,
525525
const MachineInstr *MI, IsExpiredFn IsExpired) {
526526
DenseSet<const MachineBasicBlock *> Visited;
527527
return getWaitStatesSince(IsHazard, MI->getParent(),
528-
std::next(MI->getReverseIterator()),
529-
0, IsExpired, Visited);
528+
std::next(MI->getReverseIterator()), 0, IsExpired,
529+
Visited, SIInstrInfo::getNumWaitStates);
530530
}
531531

532532
int GCNHazardRecognizer::getWaitStatesSince(IsHazardFn IsHazard, int Limit) {
@@ -1195,7 +1195,8 @@ void GCNHazardRecognizer::fixHazards(MachineInstr *MI) {
11951195
fixVALUPartialForwardingHazard(MI);
11961196
fixVALUTransUseHazard(MI);
11971197
fixVALUTransCoexecutionHazards(MI);
1198-
fixWMMAHazards(MI);
1198+
fixWMMAHazards(MI); // fall-through if co-execution is enabled.
1199+
fixWMMACoexecutionHazards(MI);
11991200
fixShift64HighRegBug(MI);
12001201
fixVALUMaskWriteHazard(MI);
12011202
fixRequiredExportPriority(MI);
@@ -1925,6 +1926,182 @@ bool GCNHazardRecognizer::fixWMMAHazards(MachineInstr *MI) {
19251926
return true;
19261927
}
19271928

1929+
static bool isCoexecutableVALUInst(const MachineInstr &MI) {
1930+
return SIInstrInfo::isVALU(MI) && !SIInstrInfo::isTRANS(MI) &&
1931+
!SIInstrInfo::isWMMA(MI) && !SIInstrInfo::isSWMMAC(MI); // What else?
1932+
}
1933+
1934+
static bool IsWMMAHazardInstInCategory(const MachineInstr &MI,
1935+
const SIInstrInfo *TII, unsigned Latency,
1936+
unsigned Category) {
1937+
assert(TII->isXDLWMMA(MI) && (Latency == 8 || Latency == 16) &&
1938+
"Handle me if the xdl wmma instruction latency changes");
1939+
1940+
switch (Category) {
1941+
case 0: // Dense WMMA Instructions:
1942+
// WMMA_*F16, WMMA_*BF16
1943+
// WMMA_*FP8FP8
1944+
// WMMA_*FP8BF8
1945+
// WMMA_*BF8FP8
1946+
// WMMA_*BF8BF8
1947+
// WMMA_*F8F6F4 if SRCA & SRCB != F8
1948+
return Latency == 8 && SIInstrInfo::isWMMA(MI);
1949+
1950+
case 1: // Dense WMMA Instructions:
1951+
// WMMA_IU8
1952+
// WMMA_IU4
1953+
// WMMA_*F8F6F4 if SRCA OR SRCB == F8
1954+
return Latency == 16 && SIInstrInfo::isWMMA(MI);
1955+
1956+
case 2: // Dense SWMMAC Instructions
1957+
// SWMMAC_*F16, SWMMAC_*BF16,
1958+
// SWMMAC_*FP8FP8
1959+
// SWMMAC_*BF8FP8
1960+
// SWMMAC_*FP8BF8
1961+
// SWMMAC_*BF8BF8
1962+
return Latency == 8 && SIInstrInfo::isSWMMAC(MI);
1963+
1964+
case 3: // Sparse WMMA Instructions:
1965+
// SWMMAC_IU8
1966+
// SWMMAC_IU4
1967+
return Latency == 16 && SIInstrInfo::isSWMMAC(MI);
1968+
default:
1969+
break;
1970+
} // end switch.
1971+
1972+
return false;
1973+
}
1974+
1975+
/// Inserts V_NOPs in front of \p MI when it is too close to a preceding XDL
/// WMMA instruction it has a register dependency on.
///
/// Only applies on GFX1250 and only when \p MI is itself an XDL WMMA or a
/// co-executable VALU (see isCoexecutableVALUInst); otherwise nothing is done
/// and false is returned. In every other case the function returns true,
/// whether or not any V_NOP was actually inserted.
bool GCNHazardRecognizer::fixWMMACoexecutionHazards(MachineInstr *MI) {
  if (!AMDGPU::isGFX1250(ST))
    return false;

  const SIInstrInfo *TII = ST.getInstrInfo();
  const bool ConsumerIsWMMA = TII->isXDLWMMA(*MI);
  if (!ConsumerIsWMMA && !isCoexecutableVALUInst(*MI))
    return false;

  const SIRegisterInfo *TRI = ST.getRegisterInfo();

  // Required distance, counted in V_NOPs or unrelated VALU instructions,
  // between the earlier WMMA and MI. The tables are indexed by the category of
  // the earlier WMMA (see IsWMMAHazardInstInCategory): the first applies when
  // MI is also a WMMA, the second when MI is a VALU. Refer to SPG 4.6.12.1.
  // "Requirements for WMMA data hazards" for the numbers.
  const int WMMAWaitStates[] = {5, 9, 3, 5};
  const int VALUWaitStates[] = {4, 8, 2, 4};

  // Category currently being searched for; read by the predicates below and
  // advanced by the search loop.
  unsigned Category = 0;

  // True if Producer is an XDL WMMA that belongs to the current Category.
  auto IsProducerInCategory = [&](const MachineInstr &Producer) {
    if (!TII->isXDLWMMA(Producer))
      return false;
    unsigned Latency = TSchedModel.computeInstrLatency(&Producer);
    return IsWMMAHazardInstInCategory(Producer, TII, Latency, Category);
  };

  // Hazard predicate used when MI is a WMMA: the earlier WMMA writes D0 and
  // MI reads it through A1/B1 (or through Idx1 when MI is an SWMMAC).
  auto WMMAConsumerHazard = [&](const MachineInstr &Producer) {
    if (!IsProducerInCategory(Producer))
      return false;

    Register ProducerDst =
        TII->getNamedOperand(Producer, AMDGPU::OpName::vdst)->getReg();
    Register SrcA = TII->getNamedOperand(*MI, AMDGPU::OpName::src0)->getReg();
    Register SrcB = TII->getNamedOperand(*MI, AMDGPU::OpName::src1)->getReg();

    if (TRI->regsOverlap(ProducerDst, SrcA) ||
        TRI->regsOverlap(ProducerDst, SrcB))
      return true;

    if (!SIInstrInfo::isSWMMAC(*MI))
      return false;

    Register Index = TII->getNamedOperand(*MI, AMDGPU::OpName::src2)->getReg();
    return TRI->regsOverlap(ProducerDst, Index);
  };

  // Hazard predicate used when MI is a co-executable VALU.
  auto VALUConsumerHazard = [&](const MachineInstr &Producer) {
    if (!IsProducerInCategory(Producer))
      return false;

    Register ProducerDst =
        TII->getNamedOperand(Producer, AMDGPU::OpName::vdst)->getReg();

    // WMMA writes, VALU reads.
    for (const MachineOperand &Use : MI->explicit_uses()) {
      if (!Use.isReg())
        continue;
      if (TRI->regsOverlap(ProducerDst, Use.getReg()))
        return true;
    }

    // The remaining checks need a register destination on the VALU.
    auto *ConsumerDstOp = TII->getNamedOperand(*MI, AMDGPU::OpName::vdst);
    if (!ConsumerDstOp || !ConsumerDstOp->isReg())
      return false;
    Register ConsumerDst = ConsumerDstOp->getReg();

    // WMMA writes, VALU writes.
    if (TRI->regsOverlap(ProducerDst, ConsumerDst))
      return true;

    // WMMA reads, VALU writes.
    Register ProducerA =
        TII->getNamedOperand(Producer, AMDGPU::OpName::src0)->getReg();
    Register ProducerB =
        TII->getNamedOperand(Producer, AMDGPU::OpName::src1)->getReg();
    if (TRI->regsOverlap(ProducerA, ConsumerDst) ||
        TRI->regsOverlap(ProducerB, ConsumerDst))
      return true;

    if (!SIInstrInfo::isSWMMAC(Producer))
      return false;

    Register ProducerIdx =
        TII->getNamedOperand(Producer, AMDGPU::OpName::src2)->getReg();
    return TRI->regsOverlap(ConsumerDst, ProducerIdx);
  };

  // The backwards search is abandoned once Limit wait states were counted.
  int Limit = 0;
  auto HazardExpired = [&Limit](const MachineInstr &, int WaitStates) {
    return WaitStates >= Limit;
  };

  // Only VALU instructions count as a wait state for this hazard.
  auto CountVALU = [](const MachineInstr &I) {
    return SIInstrInfo::isVALU(I) ? 1 : 0;
  };

  // Tries the categories in order and stops at the first one that yields a
  // non-negative result. '::getWaitStatesSince' returns the number of VALUs in
  // between if a hazard exists, and INT_MAX if there is no hazard, so a
  // negative result means "no hazard for this category, try the next one".
  auto ComputeNopsNeeded = [&](auto &HazardFn, const int (&Limits)[4]) {
    int Needed = -1;
    Category = 0;
    while (Needed < 0 && Category < 4) {
      Limit = Limits[Category]; // Also consumed by HazardExpired.
      DenseSet<const MachineBasicBlock *> Visited;
      Needed = Limit - ::getWaitStatesSince(
                           HazardFn, MI->getParent(),
                           std::next(MI->getReverseIterator()), 0,
                           HazardExpired, Visited, CountVALU);
      Category++;
    }
    return Needed;
  };

  // If MI is not an XDL WMMA it must be a co-executable VALU.
  const int WaitStatesNeeded =
      ConsumerIsWMMA ? ComputeNopsNeeded(WMMAConsumerHazard, WMMAWaitStates)
                     : ComputeNopsNeeded(VALUConsumerHazard, VALUWaitStates);

  // WaitStatesNeeded is the number of V_NOPs to insert before MI; a negative
  // value means none are needed.
  for (int NopIdx = 0; NopIdx < WaitStatesNeeded; ++NopIdx)
    BuildMI(*MI->getParent(), MI, MI->getDebugLoc(),
            TII->get(AMDGPU::V_NOP_e32));

  return true;
}
2104+
19282105
bool GCNHazardRecognizer::fixShift64HighRegBug(MachineInstr *MI) {
19292106
if (!ST.hasShift64HighRegBug())
19302107
return false;

llvm/lib/Target/AMDGPU/GCNHazardRecognizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class GCNHazardRecognizer final : public ScheduleHazardRecognizer {
106106
bool fixVALUTransUseHazard(MachineInstr *MI);
107107
bool fixVALUTransCoexecutionHazards(MachineInstr *MI);
108108
bool fixWMMAHazards(MachineInstr *MI);
109+
bool fixWMMACoexecutionHazards(MachineInstr *MI);
109110
bool fixShift64HighRegBug(MachineInstr *MI);
110111
bool fixVALUMaskWriteHazard(MachineInstr *MI);
111112
bool fixRequiredExportPriority(MachineInstr *MI);

0 commit comments

Comments
 (0)