@@ -520,8 +520,8 @@ static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard,
520520 const MachineInstr *MI, IsExpiredFn IsExpired) {
521521 DenseSet<const MachineBasicBlock *> Visited;
522522 return getWaitStatesSince (IsHazard, MI->getParent (),
523- std::next (MI->getReverseIterator ()),
524- 0 , IsExpired, Visited );
523+ std::next (MI->getReverseIterator ()), 0 , IsExpired,
524+ Visited, SIInstrInfo::getNumWaitStates );
525525}
526526
527527int GCNHazardRecognizer::getWaitStatesSince (IsHazardFn IsHazard, int Limit) {
@@ -1190,7 +1190,8 @@ void GCNHazardRecognizer::fixHazards(MachineInstr *MI) {
11901190 fixVALUPartialForwardingHazard (MI);
11911191 fixVALUTransUseHazard (MI);
11921192 fixVALUTransCoexecutionHazards (MI);
1193- fixWMMAHazards (MI);
1193+ fixWMMAHazards (MI); // fall-through if co-execution is enabled.
1194+ fixWMMACoexecutionHazards (MI);
11941195 fixShift64HighRegBug (MI);
11951196 fixVALUMaskWriteHazard (MI);
11961197 fixRequiredExportPriority (MI);
@@ -1909,6 +1910,182 @@ bool GCNHazardRecognizer::fixWMMAHazards(MachineInstr *MI) {
19091910 return true ;
19101911}
19111912
1913+ static bool isCoexecutableVALUInst (const MachineInstr &MI) {
1914+ return SIInstrInfo::isVALU (MI) && !SIInstrInfo::isTRANS (MI) &&
1915+ !SIInstrInfo::isWMMA (MI) && !SIInstrInfo::isSWMMAC (MI); // What else?
1916+ }
1917+
1918+ static bool IsWMMAHazardInstInCategory (const MachineInstr &MI,
1919+ const SIInstrInfo *TII, unsigned Latency,
1920+ unsigned Category) {
1921+ assert (TII->isXDLWMMA (MI) && (Latency == 8 || Latency == 16 ) &&
1922+ " Handle me if the xdl wmma instruction latency changes" );
1923+
1924+ switch (Category) {
1925+ case 0 : // Dense WMMA Instructions:
1926+ // WMMA_*F16, WMMA_*BF16
1927+ // WMMA_*FP8FP8
1928+ // WMMA_*FP8BF8
1929+ // WMMA_*BF8FP8
1930+ // WMMA_*BF8BF8
1931+ // WMMA_*F8F6F4 if SRCA & SRCB != F8
1932+ return Latency == 8 && SIInstrInfo::isWMMA (MI);
1933+
1934+ case 1 : // Dense WMMA Instructions:
1935+ // WMMA_IU8
1936+ // WMMA_IU4
1937+ // WMMA_*F8F6F4 if SRCA OR SRCB == F8
1938+ return Latency == 16 && SIInstrInfo::isWMMA (MI);
1939+
1940+ case 2 : // Dense SWMMAC Instructions
1941+ // SWMMAC_*F16, SWMMAC_*BF16,
1942+ // SWMMAC_*FP8FP8
1943+ // SWMMAC_*BF8FP8
1944+ // SWMMAC_*FP8BF8
1945+ // SWMMAC_*BF8BF8
1946+ return Latency == 8 && SIInstrInfo::isSWMMAC (MI);
1947+
1948+ case 3 : // Sparse WMMA Instructions:
1949+ // SWMMAC_IU8
1950+ // SWMMAC_IU4
1951+ return Latency == 16 && SIInstrInfo::isSWMMAC (MI);
1952+ default :
1953+ break ;
1954+ } // end switch.
1955+
1956+ return false ;
1957+ }
1958+
1959+ bool GCNHazardRecognizer::fixWMMACoexecutionHazards (MachineInstr *MI) {
1960+ if (!AMDGPU::isGFX1250 (ST))
1961+ return false ;
1962+
1963+ const SIInstrInfo *TII = ST.getInstrInfo ();
1964+ if (!TII->isXDLWMMA (*MI) && !isCoexecutableVALUInst (*MI))
1965+ return false ;
1966+
1967+ const SIRegisterInfo *TRI = ST.getRegisterInfo ();
1968+
1969+ // WaitStates here is the number of V_NOPs or unrelated VALU instructions must
1970+ // be in between the first WMMA and the second instruction to cover the hazard
1971+ // (WMMAWaitStates if the second is also a WMMA, VALUWaitStates if the second
1972+ // is a VALU). Refer to SPG 4.6.12.1. "Requirements for WMMA data hazards" for
1973+ // numbers, which depends on the category of the first WMMA.
1974+ const int WMMAWaitStates[] = {5 , 9 , 3 , 5 };
1975+ const int VALUWaitStates[] = {4 , 8 , 2 , 4 };
1976+ unsigned Category = 0 ;
1977+
1978+ auto IsWMMAHazardFn = [MI, TII, TRI, &Category, this ](const MachineInstr &I) {
1979+ if (!TII->isXDLWMMA (I))
1980+ return false ;
1981+
1982+ unsigned Latency = TSchedModel.computeInstrLatency (&I);
1983+ if (!IsWMMAHazardInstInCategory (I, TII, Latency, Category))
1984+ return false ;
1985+
1986+ Register D0 = TII->getNamedOperand (I, AMDGPU::OpName::vdst)->getReg ();
1987+ Register A1 = TII->getNamedOperand (*MI, AMDGPU::OpName::src0)->getReg ();
1988+ Register B1 = TII->getNamedOperand (*MI, AMDGPU::OpName::src1)->getReg ();
1989+
1990+ // WMMA0 wrires (D0), WMMA1 reads (A1/B1/Idx1).
1991+ if (TRI->regsOverlap (D0, A1) || TRI->regsOverlap (D0, B1))
1992+ return true ;
1993+
1994+ if (SIInstrInfo::isSWMMAC (*MI)) {
1995+ Register Idx1 = TII->getNamedOperand (*MI, AMDGPU::OpName::src2)->getReg ();
1996+ if (TRI->regsOverlap (D0, Idx1))
1997+ return true ;
1998+ }
1999+
2000+ return false ;
2001+ };
2002+
2003+ auto IsVALUHazardFn = [MI, TII, TRI, &Category, this ](const MachineInstr &I) {
2004+ if (!TII->isXDLWMMA (I))
2005+ return false ;
2006+
2007+ unsigned Latency = TSchedModel.computeInstrLatency (&I);
2008+ if (!IsWMMAHazardInstInCategory (I, TII, Latency, Category))
2009+ return false ;
2010+
2011+ // WMMA writes, VALU reads.
2012+ Register D0 = TII->getNamedOperand (I, AMDGPU::OpName::vdst)->getReg ();
2013+ for (const MachineOperand &ValuUse : MI->explicit_uses ()) {
2014+ if (ValuUse.isReg () && TRI->regsOverlap (D0, ValuUse.getReg ()))
2015+ return true ;
2016+ }
2017+
2018+ auto *ValuDst = TII->getNamedOperand (*MI, AMDGPU::OpName::vdst);
2019+ if (!ValuDst || !ValuDst->isReg ())
2020+ return false ;
2021+ Register D1 = ValuDst->getReg ();
2022+
2023+ // WMMA writes, VALU writes.
2024+ if (TRI->regsOverlap (D0, D1))
2025+ return true ;
2026+
2027+ // WMMA reads, VALU writes.
2028+ Register A0 = TII->getNamedOperand (I, AMDGPU::OpName::src0)->getReg ();
2029+ Register B0 = TII->getNamedOperand (I, AMDGPU::OpName::src1)->getReg ();
2030+ if (TRI->regsOverlap (A0, D1) || TRI->regsOverlap (B0, D1))
2031+ return true ;
2032+
2033+ if (SIInstrInfo::isSWMMAC (I)) {
2034+ Register Idx0 = TII->getNamedOperand (I, AMDGPU::OpName::src2)->getReg ();
2035+ if (TRI->regsOverlap (D1, Idx0))
2036+ return true ;
2037+ }
2038+
2039+ return false ;
2040+ };
2041+
2042+ int Limit = 0 ;
2043+ auto IsExpiredFn = [&Limit](const MachineInstr &, int WaitStates) {
2044+ return WaitStates >= Limit;
2045+ };
2046+
2047+ auto GetWaitStatesFn = [](const MachineInstr &I) {
2048+ return SIInstrInfo::isVALU (I) ? 1 : 0 ;
2049+ };
2050+
2051+ int WaitStatesNeeded = -1 ;
2052+ if (TII->isXDLWMMA (*MI)) {
2053+ for (Category = 0 ; WaitStatesNeeded < 0 && Category < 4 ; Category++) {
2054+ Limit = WMMAWaitStates[Category]; // for IsExpiredFn.
2055+ DenseSet<const MachineBasicBlock *> Visited;
2056+ // '::getWaitStatesSince' returns the number of VALUs in between if hazard
2057+ // exists, and INT_MAX if there is no hazard. As a result, a negative
2058+ // WaitStatesNeeded here means no hazard, and we will continue to search
2059+ // for other categories.
2060+ WaitStatesNeeded =
2061+ Limit - ::getWaitStatesSince (IsWMMAHazardFn, MI->getParent (),
2062+ std::next (MI->getReverseIterator ()), 0 ,
2063+ IsExpiredFn, Visited, GetWaitStatesFn);
2064+ }
2065+ } else { // Must be a co-executable VALU.
2066+ for (Category = 0 ; WaitStatesNeeded < 0 && Category < 4 ; Category++) {
2067+ Limit = VALUWaitStates[Category]; // for IsExpiredFn.
2068+ DenseSet<const MachineBasicBlock *> Visited;
2069+ // '::getWaitStatesSince' returns the number of VALUs in between if hazard
2070+ // exists, and INT_MAX if there is no hazard. As a result, a negative
2071+ // WaitStatesNeeded here means no hazard, and we will continue to search
2072+ // for other categories.
2073+ WaitStatesNeeded =
2074+ Limit - ::getWaitStatesSince (IsVALUHazardFn, MI->getParent (),
2075+ std::next (MI->getReverseIterator ()), 0 ,
2076+ IsExpiredFn, Visited, GetWaitStatesFn);
2077+ }
2078+ }
2079+
2080+ // WaitStatesNeeded now is the number of V_NOPs we need to insert, negative
2081+ // means not needed.
2082+ for (int i = 0 ; i < WaitStatesNeeded; i++)
2083+ BuildMI (*MI->getParent (), MI, MI->getDebugLoc (),
2084+ TII->get (AMDGPU::V_NOP_e32));
2085+
2086+ return true ;
2087+ }
2088+
19122089bool GCNHazardRecognizer::fixShift64HighRegBug (MachineInstr *MI) {
19132090 if (!ST.hasShift64HighRegBug ())
19142091 return false ;
0 commit comments