// Convenience wrapper: walk backwards from \p MI and return the number of
// wait states elapsed before an instruction matching \p IsHazard is found,
// or a very large value if the search expires first (per \p IsExpired).
// Seeds the recursive overload with MI's block, an empty visited set, and
// the default per-instruction wait-state counter
// (SIInstrInfo::getNumWaitStates).
static int getWaitStatesSince(GCNHazardRecognizer::IsHazardFn IsHazard,
                              const MachineInstr *MI, IsExpiredFn IsExpired) {
  DenseSet<const MachineBasicBlock *> Visited;
  return getWaitStatesSince(IsHazard, MI->getParent(),
                            std::next(MI->getReverseIterator()), 0, IsExpired,
                            Visited, SIInstrInfo::getNumWaitStates);
}
531531
532532int GCNHazardRecognizer::getWaitStatesSince (IsHazardFn IsHazard, int Limit) {
@@ -1195,7 +1195,8 @@ void GCNHazardRecognizer::fixHazards(MachineInstr *MI) {
11951195 fixVALUPartialForwardingHazard (MI);
11961196 fixVALUTransUseHazard (MI);
11971197 fixVALUTransCoexecutionHazards (MI);
1198- fixWMMAHazards (MI);
1198+ fixWMMAHazards (MI); // fall-through if co-execution is enabled.
1199+ fixWMMACoexecutionHazards (MI);
11991200 fixShift64HighRegBug (MI);
12001201 fixVALUMaskWriteHazard (MI);
12011202 fixRequiredExportPriority (MI);
@@ -1925,6 +1926,182 @@ bool GCNHazardRecognizer::fixWMMAHazards(MachineInstr *MI) {
19251926 return true ;
19261927}
19271928
1929+ static bool isCoexecutableVALUInst (const MachineInstr &MI) {
1930+ return SIInstrInfo::isVALU (MI) && !SIInstrInfo::isTRANS (MI) &&
1931+ !SIInstrInfo::isWMMA (MI) && !SIInstrInfo::isSWMMAC (MI); // What else?
1932+ }
1933+
1934+ static bool IsWMMAHazardInstInCategory (const MachineInstr &MI,
1935+ const SIInstrInfo *TII, unsigned Latency,
1936+ unsigned Category) {
1937+ assert (TII->isXDLWMMA (MI) && (Latency == 8 || Latency == 16 ) &&
1938+ " Handle me if the xdl wmma instruction latency changes" );
1939+
1940+ switch (Category) {
1941+ case 0 : // Dense WMMA Instructions:
1942+ // WMMA_*F16, WMMA_*BF16
1943+ // WMMA_*FP8FP8
1944+ // WMMA_*FP8BF8
1945+ // WMMA_*BF8FP8
1946+ // WMMA_*BF8BF8
1947+ // WMMA_*F8F6F4 if SRCA & SRCB != F8
1948+ return Latency == 8 && SIInstrInfo::isWMMA (MI);
1949+
1950+ case 1 : // Dense WMMA Instructions:
1951+ // WMMA_IU8
1952+ // WMMA_IU4
1953+ // WMMA_*F8F6F4 if SRCA OR SRCB == F8
1954+ return Latency == 16 && SIInstrInfo::isWMMA (MI);
1955+
1956+ case 2 : // Dense SWMMAC Instructions
1957+ // SWMMAC_*F16, SWMMAC_*BF16,
1958+ // SWMMAC_*FP8FP8
1959+ // SWMMAC_*BF8FP8
1960+ // SWMMAC_*FP8BF8
1961+ // SWMMAC_*BF8BF8
1962+ return Latency == 8 && SIInstrInfo::isSWMMAC (MI);
1963+
1964+ case 3 : // Sparse WMMA Instructions:
1965+ // SWMMAC_IU8
1966+ // SWMMAC_IU4
1967+ return Latency == 16 && SIInstrInfo::isSWMMAC (MI);
1968+ default :
1969+ break ;
1970+ } // end switch.
1971+
1972+ return false ;
1973+ }
1974+
// Mitigate gfx1250 co-execution data hazards between a leading XDL WMMA and
// \p MI by inserting V_NOPs in front of \p MI. Two shapes are handled:
//   * WMMA -> WMMA:  the trailing WMMA reads a register the leader writes.
//   * WMMA -> VALU:  any read/write register overlap between the two.
// The required separation depends on the category of the leading WMMA (see
// IsWMMAHazardInstInCategory) per SPG 4.6.12.1 "Requirements for WMMA data
// hazards".
bool GCNHazardRecognizer::fixWMMACoexecutionHazards(MachineInstr *MI) {
  if (!AMDGPU::isGFX1250(ST))
    return false;

  const SIInstrInfo *TII = ST.getInstrInfo();
  if (!TII->isXDLWMMA(*MI) && !isCoexecutableVALUInst(*MI))
    return false;

  const SIRegisterInfo *TRI = ST.getRegisterInfo();

  // WaitStates here is the number of V_NOPs or unrelated VALU instructions must
  // be in between the first WMMA and the second instruction to cover the hazard
  // (WMMAWaitStates if the second is also a WMMA, VALUWaitStates if the second
  // is a VALU). Refer to SPG 4.6.12.1. "Requirements for WMMA data hazards" for
  // numbers, which depends on the category of the first WMMA.
  const int WMMAWaitStates[] = {5, 9, 3, 5};
  const int VALUWaitStates[] = {4, 8, 2, 4};
  unsigned Category = 0;

  // Hazard: a preceding XDL WMMA (matching the current Category, captured by
  // reference) writes a register that this WMMA/SWMMAC reads.
  auto IsWMMAHazardFn = [MI, TII, TRI, &Category, this](const MachineInstr &I) {
    if (!TII->isXDLWMMA(I))
      return false;

    unsigned Latency = TSchedModel.computeInstrLatency(&I);
    if (!IsWMMAHazardInstInCategory(I, TII, Latency, Category))
      return false;

    Register D0 = TII->getNamedOperand(I, AMDGPU::OpName::vdst)->getReg();
    Register A1 = TII->getNamedOperand(*MI, AMDGPU::OpName::src0)->getReg();
    Register B1 = TII->getNamedOperand(*MI, AMDGPU::OpName::src1)->getReg();

    // WMMA0 writes (D0), WMMA1 reads (A1/B1/Idx1).
    if (TRI->regsOverlap(D0, A1) || TRI->regsOverlap(D0, B1))
      return true;

    // SWMMAC additionally reads a sparsity index from src2.
    if (SIInstrInfo::isSWMMAC(*MI)) {
      Register Idx1 = TII->getNamedOperand(*MI, AMDGPU::OpName::src2)->getReg();
      if (TRI->regsOverlap(D0, Idx1))
        return true;
    }

    return false;
  };

  // Hazard: a preceding XDL WMMA (matching the current Category) overlaps
  // this VALU in any direction — WMMA writes/VALU reads, write/write, or
  // WMMA reads/VALU writes.
  auto IsVALUHazardFn = [MI, TII, TRI, &Category, this](const MachineInstr &I) {
    if (!TII->isXDLWMMA(I))
      return false;

    unsigned Latency = TSchedModel.computeInstrLatency(&I);
    if (!IsWMMAHazardInstInCategory(I, TII, Latency, Category))
      return false;

    // WMMA writes, VALU reads.
    Register D0 = TII->getNamedOperand(I, AMDGPU::OpName::vdst)->getReg();
    for (const MachineOperand &ValuUse : MI->explicit_uses()) {
      if (ValuUse.isReg() && TRI->regsOverlap(D0, ValuUse.getReg()))
        return true;
    }

    // A VALU with no register destination cannot conflict on writes.
    auto *ValuDst = TII->getNamedOperand(*MI, AMDGPU::OpName::vdst);
    if (!ValuDst || !ValuDst->isReg())
      return false;
    Register D1 = ValuDst->getReg();

    // WMMA writes, VALU writes.
    if (TRI->regsOverlap(D0, D1))
      return true;

    // WMMA reads, VALU writes.
    Register A0 = TII->getNamedOperand(I, AMDGPU::OpName::src0)->getReg();
    Register B0 = TII->getNamedOperand(I, AMDGPU::OpName::src1)->getReg();
    if (TRI->regsOverlap(A0, D1) || TRI->regsOverlap(B0, D1))
      return true;

    // SWMMAC additionally reads a sparsity index from src2.
    if (SIInstrInfo::isSWMMAC(I)) {
      Register Idx0 = TII->getNamedOperand(I, AMDGPU::OpName::src2)->getReg();
      if (TRI->regsOverlap(D1, Idx0))
        return true;
    }

    return false;
  };

  // Limit is rebound per category below; the search stops once it is reached.
  int Limit = 0;
  auto IsExpiredFn = [&Limit](const MachineInstr &, int WaitStates) {
    return WaitStates >= Limit;
  };

  // Only VALU instructions count toward the separation distance; everything
  // else contributes zero wait states.
  auto GetWaitStatesFn = [](const MachineInstr &I) {
    return SIInstrInfo::isVALU(I) ? 1 : 0;
  };

  int WaitStatesNeeded = -1;
  if (TII->isXDLWMMA(*MI)) {
    for (Category = 0; WaitStatesNeeded < 0 && Category < 4; Category++) {
      Limit = WMMAWaitStates[Category]; // for IsExpiredFn.
      DenseSet<const MachineBasicBlock *> Visited;
      // '::getWaitStatesSince' returns the number of VALUs in between if hazard
      // exists, and INT_MAX if there is no hazard. As a result, a negative
      // WaitStatesNeeded here means no hazard, and we will continue to search
      // for other categories.
      WaitStatesNeeded =
          Limit - ::getWaitStatesSince(IsWMMAHazardFn, MI->getParent(),
                                       std::next(MI->getReverseIterator()), 0,
                                       IsExpiredFn, Visited, GetWaitStatesFn);
    }
  } else { // Must be a co-executable VALU.
    for (Category = 0; WaitStatesNeeded < 0 && Category < 4; Category++) {
      Limit = VALUWaitStates[Category]; // for IsExpiredFn.
      DenseSet<const MachineBasicBlock *> Visited;
      // '::getWaitStatesSince' returns the number of VALUs in between if hazard
      // exists, and INT_MAX if there is no hazard. As a result, a negative
      // WaitStatesNeeded here means no hazard, and we will continue to search
      // for other categories.
      WaitStatesNeeded =
          Limit - ::getWaitStatesSince(IsVALUHazardFn, MI->getParent(),
                                       std::next(MI->getReverseIterator()), 0,
                                       IsExpiredFn, Visited, GetWaitStatesFn);
    }
  }

  // WaitStatesNeeded now is the number of V_NOPs we need to insert, negative
  // means not needed.
  for (int i = 0; i < WaitStatesNeeded; i++)
    BuildMI(*MI->getParent(), MI, MI->getDebugLoc(),
            TII->get(AMDGPU::V_NOP_e32));

  return true;
}
2104+
19282105bool GCNHazardRecognizer::fixShift64HighRegBug (MachineInstr *MI) {
19292106 if (!ST.hasShift64HighRegBug ())
19302107 return false ;
0 commit comments